diff --git a/.codecov.yml b/.codecov.yml index a3ed7f47..db64e296 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -4,7 +4,8 @@ coverage: patch: false project: default: - threshold: 50% + target: 90% + threshold: 5% comment: layout: "header" require_changes: false diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index c772b96d..d41fc409 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,12 +1,19 @@ -## Description +## Pull Request Summary Provide a brief description of the PR's purpose here. -## Todos +### Key changes Notable points that this PR has either accomplished or will accomplish. - - [ ] TODO 1 + - [ ] Change 1 -## Questions -- [ ] Question1 +### Questions + - [ ] Question 1 -## Status -- [ ] Ready to go \ No newline at end of file +### Associated Issue(s) + - [ ] Issue 1 + +## Pull Request Checklist + - [ ] Issue(s) raised/addressed and linked + - [ ] Includes appropriate unit test(s) + - [ ] Appropriate docstring(s) added/updated + - [ ] Appropriate .rst doc file(s) added/updated + - [ ] PR is ready for review \ No newline at end of file diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 6b852a66..c8cf5c9e 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -35,6 +35,11 @@ jobs: matrix: os: [macOS-latest, ubuntu-latest] python-version: ["3.10", "3.11"] + include: + - os: ubuntu-latest + n_procs: auto + - os: macOS-latest + n_procs: 0 steps: - uses: actions/checkout@v3 @@ -73,12 +78,15 @@ jobs: - name: Run tests # conda setup requires this special shell run: | - pytest -v --cov=modelforge --cov-report=xml --color=yes --durations=10 modelforge/tests/ + pytest -n ${{matrix.n_procs}} --dist loadgroup -v --cov=modelforge --cov-report=xml --color=yes --durations=50 modelforge/tests/ - name: CodeCov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} if: ${{ github.event != 'schedule' }} # Don't upload results on scheduled runs with: + token: ${{ secrets.CODECOV_TOKEN }} file: ./coverage.xml flags: unittests name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 00000000..98f17e43 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,19 @@ +name: Lint + +on: + pull_request: + branches: + - main + push: + branches: + - main + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + options: "--check --verbose --line-length 88" + src: "./modelforge" diff --git a/.gitignore b/.gitignore index c3e870b8..3c452056 100644 --- a/.gitignore +++ b/.gitignore @@ -182,61 +182,15 @@ $RECYCLE.BIN/ *.lnk # End of https://www.toptal.com/developers/gitignore/api/osx,windows,linux -qm9tut/ lightning_logs/ *.gz *.npz *.hdf5 -scripts/tb_logs/* -notebooks/tb_logs/training/version_0/events.out.tfevents.1712567395.fedora.24017.0 -notebooks/tb_logs/training/version_0/hparams.yaml -notebooks/tb_logs/training/version_0/checkpoints/epoch=0-step=210.ckpt -notebooks/tb_logs/training/version_1/events.out.tfevents.1712567945.fedora.24017.1 -notebooks/tb_logs/training/version_1/hparams.yaml -notebooks/tb_logs/training/version_1/checkpoints/epoch=4-step=1050.ckpt -notebooks/tb_logs/training/version_10/events.out.tfevents.1713300058.fedora.156712.0 -notebooks/tb_logs/training/version_10/hparams.yaml -notebooks/tb_logs/training/version_10/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_11/events.out.tfevents.1713300137.fedora.157775.0 -notebooks/tb_logs/training/version_11/hparams.yaml -notebooks/tb_logs/training/version_11/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_12/events.out.tfevents.1713300163.fedora.158807.0 -notebooks/tb_logs/training/version_12/hparams.yaml -notebooks/tb_logs/training/version_12/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_13/events.out.tfevents.1713300213.fedora.159869.0 -notebooks/tb_logs/training/version_13/hparams.yaml -notebooks/tb_logs/training/version_13/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_14/events.out.tfevents.1713300437.fedora.161554.0 -notebooks/tb_logs/training/version_14/hparams.yaml -notebooks/tb_logs/training/version_14/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_15/events.out.tfevents.1713301163.fedora.166274.0 -notebooks/tb_logs/training/version_15/hparams.yaml -notebooks/tb_logs/training/version_15/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_16/events.out.tfevents.1715177168.fedora.59344.0 -notebooks/tb_logs/training/version_16/events.out.tfevents.1715177661.fedora.59344.1 -notebooks/tb_logs/training/version_16/events.out.tfevents.1715177672.fedora.59344.2 -notebooks/tb_logs/training/version_16/events.out.tfevents.1715177678.fedora.59344.3 -notebooks/tb_logs/training/version_16/hparams.yaml -notebooks/tb_logs/training/version_2/events.out.tfevents.1712569996.fedora.33244.0 -notebooks/tb_logs/training/version_3/events.out.tfevents.1712570094.fedora.33244.1 -notebooks/tb_logs/training/version_3/hparams.yaml -notebooks/tb_logs/training/version_4/events.out.tfevents.1712570244.fedora.33244.2 -notebooks/tb_logs/training/version_4/hparams.yaml -notebooks/tb_logs/training/version_4/checkpoints/epoch=4-step=1050.ckpt -notebooks/tb_logs/training/version_5/events.out.tfevents.1713272484.fedora.65117.0 -notebooks/tb_logs/training/version_5/hparams.yaml -notebooks/tb_logs/training/version_5/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_6/events.out.tfevents.1713273578.fedora.65117.1 -notebooks/tb_logs/training/version_6/hparams.yaml -notebooks/tb_logs/training/version_6/checkpoints/epoch=99-step=100.ckpt -notebooks/tb_logs/training/version_7/events.out.tfevents.1713273988.fedora.65117.2 -notebooks/tb_logs/training/version_7/hparams.yaml -notebooks/tb_logs/training/version_7/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_8/events.out.tfevents.1713274608.fedora.85750.0 -notebooks/tb_logs/training/version_8/hparams.yaml -notebooks/tb_logs/training/version_8/checkpoints/epoch=4-step=5.ckpt -notebooks/tb_logs/training/version_9/events.out.tfevents.1713299697.fedora.154578.0 -notebooks/tb_logs/training/version_9/hparams.yaml -notebooks/tb_logs/training/version_9/checkpoints/epoch=4-step=5.ckpt +*/tb_logs/* +.vscode/settings.json +logs/* +cache/* +*/logs/* +*/cache/* diff --git a/.readthedocs.yaml b/.readthedocs.yaml index a1222759..3a52d3b9 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -10,7 +10,7 @@ sphinx: fail_on_warning: true conda: - environment: devtools/conda-envs/doc_env.yaml + environment: devtools/conda-envs/docs_env.yaml python: # Install our python package before building the docs diff --git a/MANIFEST.in b/MANIFEST.in index e0267afd..54419180 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,8 @@ include CODE_OF_CONDUCT.md +include modelforge/dataset/yaml_files/*.yaml +include modelforge/curation/yaml_files/*.yaml +include modelforge/tests/data/potential_defaults/*.toml +include modelforge/tests/data/training_defaults/*.toml + global-exclude *.py[cod] __pycache__ *.so diff --git a/README.md b/README.md index 3bb76a47..ca1d7f88 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ modelforge [//]: # (Badges) [![CI](https://github.com/choderalab/modelforge/actions/workflows/CI.yaml/badge.svg)](https://github.com/choderalab/modelforge/actions/workflows/CI.yaml) [![codecov](https://codecov.io/gh/choderalab/modelforge/branch/main/graph/badge.svg)](https://codecov.io/gh/choderalab/modelforge/branch/main) +[![Documentation Status](https://readthedocs.org/projects/modelforge/badge/?version=latest)](https://modelforge.readthedocs.io/en/latest/?badge=latest) [![Github release](https://badgen.net/github/release/choderalab/modelforge)](https://github.com/choderalab/modelforge/) [![GitHub license](https://img.shields.io/github/license/choderalab/modelforge?color=green)](https://github.com/choderalab/modelforge/blob/main/LICENSE) [![GitHub issues](https://img.shields.io/github/issues/choderalab/modelforge?style=flat)](https://github.com/choderalab/modelforge/issues) diff --git a/devtools/conda-envs/doc_env.yaml b/devtools/conda-envs/docs_env.yaml similarity index 84% rename from devtools/conda-envs/doc_env.yaml rename to devtools/conda-envs/docs_env.yaml index a623f393..17156b81 100644 --- a/devtools/conda-envs/doc_env.yaml +++ b/devtools/conda-envs/docs_env.yaml @@ -8,7 +8,7 @@ dependencies: - pip - h5py - tqdm - - qcelemental=0.25.1 + - qcelemental - qcportal>=0.50 - pytorch>=2.1 - loguru @@ -21,6 +21,8 @@ dependencies: - rdkit - retry - sqlitedict + - pydantic>=2 + - ray-all # Testing - pytest>=2.1 @@ -36,5 +38,6 @@ dependencies: - jax - flax - pytorch2jax + #- "ray[data,train,tune,serve]" - git+https://github.com/ArnNag/sake.git@nanometer - - "ray[data,train,tune,serve]" + - torchviz2 diff --git a/devtools/conda-envs/env.yaml b/devtools/conda-envs/env.yaml new file mode 100644 index 00000000..8535f561 --- /dev/null +++ b/devtools/conda-envs/env.yaml @@ -0,0 +1,27 @@ +name: modelforge_env +channels: + - conda-forge + - pytorch +dependencies: + # Base depends + - python + - pip + - h5py + - tqdm + - toml + - qcportal>=0.50 + - qcelemental + - pytorch>=2.1 + - loguru + - lightning>=2.0.8 + - tensorboard + - torchvision + - openff-units + - torchmetrics>=1.4 + - pint=0.23 + - rdkit + - retry + - sqlitedict + - pydantic>=2 + - ray-all + - jax diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 3c13c51d..ae6027cb 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -8,8 +8,9 @@ dependencies: - pip - h5py - tqdm - - qcelemental=0.25.1 + - toml - qcportal>=0.50 + - qcelemental - pytorch>=2.1 - loguru - lightning>=2.0.8 @@ -17,10 +18,13 @@ dependencies: - torchvision - openff-units - torchmetrics>=1.4 - - pint + - pint=0.23 - rdkit - retry - sqlitedict + - pydantic>=2 + - ray-all + - graphviz # Testing - pytest>=2.1 @@ -29,11 +33,12 @@ dependencies: - requests - versioneer - # Docs - - sphinx_rtd_theme - - pip: + - jax + - flax - pytorch2jax - git+https://github.com/ArnNag/sake.git@nanometer - - "ray[data,train,tune,serve]" - flax + - torch + - pytest-xdist + diff --git a/devtools/conda-envs/test_env_mac.yaml b/devtools/conda-envs/test_env_mac.yaml index cf5b8b60..65b322d4 100644 --- a/devtools/conda-envs/test_env_mac.yaml +++ b/devtools/conda-envs/test_env_mac.yaml @@ -9,8 +9,8 @@ dependencies: - pip - h5py - tqdm - - qcelemental=0.25.1 - qcportal>=0.50 + - qcelemental - pytorch>=2.1 - loguru - lightning>=2.0.8 @@ -23,6 +23,9 @@ dependencies: - sqlitedict - jax - flax + - pydantic>=2.0 + - graphviz + - # Testing - pytest>=2.1 @@ -31,12 +34,10 @@ dependencies: - requests - versioneer - # Docs - - sphinx_rtd_theme - # pip installs - pip: - pytorch2jax - git+https://github.com/ArnNag/sake.git@nanometer - jax - flax + - pytest-xdist diff --git a/docs/_static/README.md b/docs/_static/README.md deleted file mode 100644 index 2f0cf843..00000000 --- a/docs/_static/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Static Doc Directory - -Add any paths that contain custom static files (such as style sheets) here, -relative to the `conf.py` file's directory. -They are copied after the builtin static files, -so a file named "default.css" will overwrite the builtin "default.css". - -The path to this folder is set in the Sphinx `conf.py` file in the line: -```python -templates_path = ['_static'] -``` - -## Examples of file to add to this directory -* Custom Cascading Style Sheets -* Custom JavaScript code -* Static logo images diff --git a/docs/_static/test.txt b/docs/_static/test.txt new file mode 100644 index 00000000..9daeafb9 --- /dev/null +++ b/docs/_static/test.txt @@ -0,0 +1 @@ +test diff --git a/docs/_templates/README.md b/docs/_templates/README.md deleted file mode 100644 index 3f4f8043..00000000 --- a/docs/_templates/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Templates Doc Directory - -Add any paths that contain templates here, relative to -the `conf.py` file's directory. -They are copied after the builtin template files, -so a file named "page.html" will overwrite the builtin "page.html". - -The path to this folder is set in the Sphinx `conf.py` file in the line: -```python -html_static_path = ['_templates'] -``` - -## Examples of file to add to this directory -* HTML extensions of stock pages like `page.html` or `layout.html` diff --git a/docs/api.rst b/docs/api.rst index 96e2fb5d..b308beaa 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -2,6 +2,11 @@ API Documentation ================= .. autosummary:: - :toctree: autosummary + :toctree: autosummary + :recursive: + + modelforge.potential + modelforge.dataset + modelforge.utils + modelforge.curation - modelforge.canvas diff --git a/docs/conf.py b/docs/conf.py index 9490fea8..73aa86cd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,22 +15,25 @@ # Incase the project was not installed import os import sys -sys.path.insert(0, os.path.abspath('..')) + +sys.path.insert(0, os.path.abspath("..")) import modelforge # -- Project information ----------------------------------------------------- -project = 'modelforge' -copyright = ("2023, Marcus Wieder. Project structure based on the " - "Computational Molecular Science Python Cookiecutter version 1.1") -author = 'Marcus Wieder' +project = "modelforge" +copyright = ( + "2024, Chodera Lab. Project structure based on the " + "Computational Molecular Science Python Cookiecutter version 1.1" +) +author = "Marcus Wieder" # The short X.Y version -version = '' +version = "" # The full version, including alpha/beta/rc tags -release = '' +release = "" # -- General configuration --------------------------------------------------- @@ -43,13 +46,14 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autosummary', - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx.ext.napoleon', - 'sphinx.ext.intersphinx', - 'sphinx.ext.extlinks', + "sphinx.ext.autosummary", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx.ext.napoleon", + "sphinx.ext.intersphinx", + "sphinx.ext.extlinks", + "sphinx.ext.autosectionlabel", ] autosummary_generate = True @@ -57,17 +61,21 @@ napoleon_use_param = False napoleon_use_ivar = True +autodoc_docstring_signature = True +autoclass_content = "both" # Add this line to combine class and __init__ docstrings + + # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -79,10 +87,10 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'default' +pygments_style = "default" # -- Options for HTML output ------------------------------------------------- @@ -90,7 +98,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -101,7 +109,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -117,7 +125,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'modelforgedoc' +htmlhelp_basename = "modelforgedoc" # -- Options for LaTeX output ------------------------------------------------ @@ -126,15 +134,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -144,8 +149,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'modelforge.tex', 'modelforge Documentation', - 'modelforge', 'manual'), + (master_doc, "modelforge.tex", "modelforge Documentation", "modelforge", "manual"), ] @@ -153,10 +157,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'modelforge', 'modelforge Documentation', - [author], 1) -] +man_pages = [(master_doc, "modelforge", "modelforge Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -165,9 +166,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'modelforge', 'modelforge Documentation', - author, 'modelforge', 'Infrastructure to implement and train NNPs', - 'Miscellaneous'), + ( + master_doc, + "modelforge", + "modelforge Documentation", + author, + "modelforge", + "Infrastructure to implement and train NNPs", + "Miscellaneous", + ), ] diff --git a/docs/datasets.rst b/docs/datasets.rst new file mode 100644 index 00000000..d37e98d7 --- /dev/null +++ b/docs/datasets.rst @@ -0,0 +1,106 @@ +Datasets +=============== + +The dataset module in modelforge provides a comprehensive suite of functions and classes designed to retrieve, transform, and store quantum mechanics (QM) datasets from QCArchive. These datasets are delivered in a format compatible with `torch.utils.data.Dataset `_, facilitating the training of machine learning potentials. The module supports actions related to data storage, caching, retrieval, and the conversion of stored HDF5 files into PyTorch-compatible datasets for training purposes. + + +General Workflow +---------------- +The typical workflow to interact with public datasets includes the following steps: + +1. *Obtaining the Dataset*: Download the raw dataset from QCArchive or another source. +2. *Processing the Dataset*: Convert the raw data into a standardized format and store it in an HDF5 file with consistent naming conventions and units. +3. *Uploading and Updating*: Upload the processed dataset to Zenodo and update the retrieval link in the dataset implementation within *modelforge*. + +For more information on how units are handled within the dataset, available properties, and instructions on developing custom datasets for *modelforge*, please refer to the `developer documentation `_. + +Available Datasets +------------------ + +The following datasets are available for use with `modelforge`: + +- :py:class:`~modelforge.dataset.QM9Dataset` +- :py:class:`~modelforge.dataset.ANI1xDataset` +- :py:class:`~modelforge.dataset.ANI2xDataset` +- :py:class:`~modelforge.dataset.SPICE1Dataset` +- :py:class:`~modelforge.dataset.SPICE1_OPENFF` +- :py:class:`~modelforge.dataset.SPICE2Dataset` +- :py:class:`~modelforge.dataset.PhAlkEthOHDataset` + +These datasets encompass a variety of molecular structures and properties, providing robust training data for developing machine learning potentials. + +Postprocessing of dataset entries +----------------------------------- + +Two common postprocessing operations are performed for training machine learned potentials: + +- *Removing Self-Energies*: Self-energies are per-element offsets added to the + total energy of a system. These offsets are not useful for training + machine-learned potentials and can be removed to provide cleaner training + data. +- *Normalization and Scaling*: Normalize the energies and other properties to + ensure they are on a comparable scale, which can improve the stability and + performance of the machine learning model. Note that this is done when atomic energies are predicted, i.e. the atomic energy (`E_i`) is scaled using the atomic energy distribution obtained from the training dataset: `E_i = E_i_stddev * E_i_pred + E_i_mean`. + + +Interacting with the Dataset Module +----------------------------------- + +The dataset module provides a :class:`~modelforge.dataset.DataModule` class for +preparing and setting up datasets for training. Designed to integrate seamlessly +with PyTorch Lightning, the :class:`~modelforge.dataset.DataModule` class +provides a user-friendly interface for dataset preparation and loading. + +The following example demonstrates how to use the :class:`~modelforge.dataset.DataModule` class to prepare and set up a dataset for training: + + +.. code-block:: python + + from modelforge.dataset import DataModule + from modelforge.dataset.utils import RandomRecordSplittingStrategy + + dataset_name = "QM9" + splitting_strategy = RandomRecordSplittingStrategy() # split randomly on system level + batch_size = 64 + version_select = "latest" + remove_self_energies = True # remove the atomic self energies + regression_ase = False # use the atomic self energies provided by the dataset + + data_module = DataModule( + name=dataset_name, + splitting_strategy=splitting_strategy, + batch_size=batch_size, + version_select=version_select, + remove_self_energies=remove_self_energies, + regression_ase=regression_ase, + ) + + # Prepare the data (downloads, processes, and caches if necessary) + data_module.prepare_data() + + # Setup the data for training, validation, and testing + data_module.setup() + +.. _dataset-configuration: + +Dataset Configuration +------------------------------------ + +Dataset configuration in modelforge is typically managed using a TOML file. This configuration file is crucial during the training process as it overrides the default values specified in the :class:`~modelforge.dataset.DataModule` class, ensuring a flexible and customizable setup. + +Below is a minimal example of a dataset configuration for the QM9 dataset. + +.. literalinclude:: ../modelforge/tests/data/dataset_defaults/qm9.toml + :language: toml + :caption: QM9 Dataset Configuration + +.. warning:: + The ``version_select`` field in the example indicates the use of a small subset of the QM9 dataset. To utilize the full dataset, set this variable to ``latest``. + + +Explanation of fields in `qm9.toml`: + +- `dataset_name`: Specifies the name of the dataset. For this example, it is QM9. +- `number_of_worker`: Determines the number of worker threads for data loading. Increasing the number of workers can speed up data loading but requires more memory. +- `version_select`: Indicates the version of the dataset to use. In this example, it points to a small subset of the dataset for quick testing. To use the full QM9 dataset, set this variable to `latest`. + diff --git a/docs/for_developer.rst b/docs/for_developer.rst new file mode 100644 index 00000000..4a97af53 --- /dev/null +++ b/docs/for_developer.rst @@ -0,0 +1,115 @@ +For Developers +=============== + +.. note:: + + This section is intended for developers who want to extend the functionality of the `modelforge` package or who want to develop their own machine learned potentials. This section explains design decisions + and the structure of neural network potentials. + + + +How to deal with units +--------------------------------- + +All public APIs require explicit units for values that are not dimensionless. +The units are specified using the `openff.units` package, e.g.: + +.. code-block:: python + + from openff.units import unit + + # A length of 1.0 angstrom + length = 1.0 * unit.angstrom + + +Units are also provided in the TOML files. For example, the following TOML file specifies a maximum interaction radius of 5.1 angstrom: + +.. code-block:: toml + + maximum_interaction_radius = "5.1 angstrom" + + +Internally, when units are removed, we use the openmm units system +`here `_. + + +Base structure of machine learned potentials +------------------------------------------------- + +The base structure of machine-learned potentials is illustrated in the figure +below. This structure is implemented in the +:py:class:`~modelforge.potential.models.BaseNetwork` class class within the +:py:mod:`~modelforge.potential.models`` module. The +:py:class:`~modelforge.potential.models.BaseNetwork` class serves as a +comprehensive framework that encapsulates both the neighbor list calculation, +handled by the +:py:class:`~modelforge.potential.models.ComputeInteractingAtomPairs` class, and +the core neural network potential, represented by the +:py:class:`~modelforge.potential.models.CoreNetwork` class. + +The neighbor list calculation is a critical component that determines which pairs of atoms in a molecular system are close enough to interact. This calculation not only identifies the interacting atom pairs but also computes the distances and distance vectors between each pair. These computed distances and vectors are essential for accurately modeling interatomic interactions in the neural network. + +The :py:class:`~modelforge.potential.models.CoreNetwork` class is responsible for producing a variable number of scalar outputs, such as per-atom energies (`E_i``) and partial charges (`q_i``). Additionally, the :py:class:`~modelforge.potential.models.CoreNetwork` maintains and outputs a feature representation of the atoms before this representation is processed by the readout layers, which are specialized for each scalar property. The readout layers transform the intermediate atomic feature representations into the final properties of interest, such as energies or charges. + +To further process these outputs, the :py:class:`~modelforge.potential.models.Postprocessing` class is used. This class performs several reduction operations, including summing the per-atom energies to compute the total molecular energy. The Postprocessing module may also perform other critical operations, such as calculating molecular self-energy corrections, which are necessary for accurate total energy predictions. + + + +.. image:: image/overview_network.png + :width: 400 + :align: center + +At a high level, as depicted in the figure below, the inputs to the core network include the following: + +- **Pairwise atom indices (ij)**: These indices specify the atom pairs that interact. +- **Distances (d_ij)**: The scalar distance between each interacting atom pair. +- **Distance vectors (r_ij)**: The vector pointing from one atom in the pair to the other. +- **Atomic numbers (Z)**: The nuclear charge of each atom, which influences its chemical properties. +- **Total charge (Q)**: The net charge of the entire system, which can affect the electrostatic interactions. +- **Coordinates (R)**: The 3D spatial positions of the atoms. + +The output of the core network includes the atomic energies (`E_i``) and the scalar feature representations of these energies before they undergo final processing through the readout layers. Additionally, the output can include other per-atom properties and associated scalar or vectorial feature representations, depending on the specific capabilities and configuration of the network. + +The operations within the CoreNetwork are divided into two main modules: + +- **Representation Module**: This module is responsible for embedding the atomic numbers and generating features from the 3D coordinates. It effectively translates the raw atomic inputs into a format that the neural network can process. + +- **Interaction Module**: This module iteratively updates the atomic feature representations by learning from local, pairwise interactions between atoms. Over multiple iterations, the interaction module refines the atomic features to capture the complex dependencies between atoms in the molecular system. + + + +.. image:: image/overview_core_network.png + :width: 400 + :align: center + :alt: Alternative text + +These components work together to enable the accurate prediction of molecular properties from atomic and molecular inputs, forming the foundation of machine-learned potentials in computational chemistry and materials science. + +Contributing to the modelforge package +--------------------------------------- + +The `modelforge` package is an open-source project and we welcome contributions from the community. +In general, modelforge uses the `Fork & Pull `_ approach for contributions. +Any github user can fork the project and submit a pull request with their changes. +The modelforge team will review the pull request and merge it if it meets the project's standards. + +Before contributing changes or additions to *modelforge*, first open an issue on github to discuss the proposed changes. +This will help ensure that the changes are in line with the project's goals and that they are implemented in a way that +is consistent with the rest of the codebase. +This will also ensure that such proposed changes are not already being worked on by another contributor. + +When contributing to *modelforge*, please follow the guidelines below: + +- Open an issue on github to discuss the proposed changes before starting work. +- Fork the project and create a new branch for your changes. +- Use the `PEP8 `_ style guide for Python code. +- Use `black `_ for code formatting. +- Use Python type hints to improve code readability and maintainability. +- Include docstrings formatted using the `numpydoc `_ style. +- Include unit tests for new code, including both positive and negative tests when applicable, and test edge cases. +- Include/update documentation in the form of .rst files (in the docs folder). +- Ensure that the code is well-commented. +- If applicable, include examples of how to use the new code. +- For more complex additions (e.g., new NNPs) create a new wiki page to explain the underlying algorithms. + +For more information on contributing to open-source projects, see the `Open Source Guides `_. diff --git a/docs/getting_started.rst b/docs/getting_started.rst index f52a60f0..1a706031 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -1,4 +1,250 @@ Getting Started -=============== +########################## -This page details how to get started with modelforge. + +This page details how to get started with *modelforge*, a library designed for training and evaluating machine learning models for interatomic potentials. + + +Installation +************** + + +To ensure a clean environment for *modelforge*, we recommend creating a new conda environment using the provided environment YAML file located in the devtools directory of the *modelforge* repository. + + +.. code-block:: bash + + conda env create -f devtools/conda-envs/test_env.yml --name modelforge + +This command will create a new conda environment named *modelforge* with all necessary dependencies installed. Note, this package has currently been tested and validated with Python 3.10 and 3.11. + +Next, clone the source code from the GitHub repository: + + +.. code-block:: bash + + git clone https://github.com/choderalab/modelforge + +Navigate to the top level of the *modelforge* directory and install the package using pip: + +.. code-block:: bash + + pip install . + +.. note:: + Modelforge has currently been tested and validated with Python 3.10 and 3.11 on various Linux distributions. While it may work on other platforms, at this time we recommend using a Linux environment. + +Use Cases for *Modelforge* +**************************** + + +*Modelforge* may be a good fit for your project if you are interested in: + +- Training machine-learned interatomic potentials using PyTorch and PyTorch Lightning. +- Utilizing pre-trained models for inference tasks. +- Exporting trained models to Jax for accelerated computation. +- Investigating the impact of hyperparameters on model performance. +- Developing new machine learning models without managing dataset and training infrastructure. + + + +How to Use *Modelforge* +**************************** + +Training a Model +============================ + +Before training a model, consider the following: + +1. **Architecture**: Which neural network architecture do you want to use? +2. **Training Set**: What dataset will you use for training? +3. **Loss Function**: Which properties will you include in the loss function? + +*Modelforge* currently supports the following architectures: + +- Invariant architectures: + * SchNet + * ANI2x + +- Equivariant architectures: + * PaiNN + * PhysNet + * TensorNet + * SAKE + +These architectures can be trained on the following datasets (distributed via zenodo https://zenodo.org/communities/modelforge/ ): + +- Ani1x +- Ani2x +- PHALKETOH +- QM9 +- SPICE1 (/openff) +- SPICE2 + +By default, potentials predict the total energy and per-atom forces within a given cutoff radius and can be trained on energies and forces. + +.. note:: PaiNN and PhysNet can also predict partial charges and calculate long-range interactions. PaiNN can additionally use multipole expansions. These features will introduce additional terms to the loss function. + +In the following example, we will train a SchNet model on the ANI1x dataset with energies and forces as the target properties. TOML files are used to define the potential architecture, dataset, training routine, and runtime parameters. + +Defining the Potential ++++++++++++++++++++++++++++++++++++++++ + +The potential architecture and relevant parameters are defined in a TOML configuration file. Here is an example of a potential definition for a SchNet model. Note that we use 16 radial basis functions, a maximum interaction radius of 5.0 angstroms, and 16 filters. We use a `ShiftedSoftplus`` activation (the fully differentiable version of ReLu) function and featurize the atomic number of the atoms in the dataset. Finally, we normalize the per-atom energy and reduce the per-atom energy to the per-molecule energy (which will then be returned).. + + +.. code-block:: toml + + [potential] + potential_name = "SchNet" + + [potential.core_parameter] # Parameters defining the architecture of the model + number_of_radial_basis_functions = 16 + maximum_interaction_radius = "5.0 angstrom" + number_of_interaction_modules = 3 + number_of_filters = 16 + shared_interactions = false + + [potential.core_parameter.activation_function_parameter] + activation_function_name = "ShiftedSoftplus" + + [potential.core_parameter.featurization] # Parameters defining the embedding of the input data + properties_to_featurize = ['atomic_number'] + maximum_atomic_number = 101 + number_of_per_atom_features = 32 + + [potential.postprocessing_parameter] + [potential.postprocessing_parameter.per_atom_energy] + normalize = true + from_atom_to_system_reduction = true + keep_per_atom_property = true + +Defining the Dataset ++++++++++++++++++++++++++++++++++++++++ + +The following TOML file defines the ANI1x dataset, allowing users to specify a specific version, as well as parameters used by the torch dataloaders (num_workers and pin_memory): + +.. code-block:: toml + + [dataset] + dataset_name = "ANI1x" + version_select = "latest" + num_workers = 4 + pin_memory = true + + +Defining the Training Routine ++++++++++++++++++++++++++++++++++++++++ + + +The training TOML file includes the number of epochs, batch size, learning rate, logger, callback parameters, and other training parameters (including dataset splitting). +Each of these settings plays a crucial role in the training process. + +Here is an example of a training routine definition: + +.. code-block:: toml + + [training] + number_of_epochs = 2 # Total number of training epochs + remove_self_energies = true # Whether to remove self-energies from the dataset + batch_size = 128 # Number of samples per batch + lr = 1e-3 # Learning rate for the optimizer + monitor = "val/per_system_energy/rmse" # Metric to monitor for checkpointing + + + [training.experiment_logger] + logger_name = "wandb" # Logger to use for tracking the training process + + [training.experiment_logger.wandb_configuration] + save_dir = "logs" # Directory to save logs + project = "training_test" # WandB project name + group = "modelforge_nnps" # WandB group name + log_model = true # Whether to log the model in WandB + job_type = "training" # Job type for WandB logging + tags = ["modelforge", "v_0.1.0"] # Tags for WandB logging + notes = "testing training" # Notes for WandB logging + + [training.lr_scheduler] + frequency = 1 # Frequency of learning rate updates + mode = "min" # Mode for the learning rate scheduler (minimizing the monitored metric) + factor = 0.1 # Factor by which the learning rate will be reduced + patience = 10 # Number of epochs with no improvement after which learning rate will be reduced + cooldown = 5 # Number of epochs to wait before resuming normal operation after learning rate has been reduced + min_lr = 1e-8 # Minimum learning rate + threshold = 0.1 # Threshold for measuring the new optimum, to only focus on significant changes + threshold_mode = "abs" # Mode for the threshold (absolute or relative) + monitor = "val/per_system_energy/rmse" # Metric to monitor for learning rate adjustments + interval = "epoch" # Interval for learning rate updates (per epoch) + + [training.loss_parameter] + loss_components = ['per_system_energy', 'per_atom_force'] # Properties to include in the loss function + + [training.loss_parameter.weight] + per_system_energy = 0.999 # Weight for per molecule energy in the loss calculation + per_atom_force = 0.001 # Weight for per atom force in the loss calculation + + [training.early_stopping] + verbose = true # Whether to print early stopping messages + monitor = "val/per_system_energy/rmse" # Metric to monitor for early stopping + min_delta = 0.001 # Minimum change to qualify as an improvement + patience = 50 # Number of epochs with no improvement after which training will be stopped + + [training.splitting_strategy] + name = "random_record_splitting_strategy" # Strategy for splitting the dataset + data_split = [0.8, 0.1, 0.1] # Proportions for training, validation, and test sets + seed = 42 # Random seed for reproducibility + +Defining Runtime Variables ++++++++++++++++++++++++++++++++++++++++ + +To define various aspects of the compute environment, various runtime parameters can be set. Here is an example of a runtime variable definition: + +.. code-block:: toml + + [runtime] + save_dir = "lightning_logs" # Directory to save logs and checkpoints + experiment_name = "exp1" # Name of the experiment + local_cache_dir = "./cache" # Directory for caching data + accelerator = "cpu" # Type of accelerator to use (e.g., 'cpu' or 'gpu') + number_of_nodes = 1 # Number of nodes to use for distributed training + devices = 1 # Number of devices to use + checkpoint_path = "None" # Path to a checkpoint to resume training + simulation_environment = "PyTorch" # Simulation environment + log_every_n_steps = 50 # Frequency of logging steps + +All of the above TOML files can be passed invididually or combined into a single TOML file that defines the training run. Assuming the combined TOML file is called `training.toml`, start the training by passing the TOML file to the perform_training.py script. + + +.. code-block:: bash + + python scripts/perform_training.py + --condensed_config_path="training.toml" + +*modelforge* uses Pydantic to validate the TOML files, ensuring that all required fields are present and that the values are of the correct type before any expensive computational operations are performed. This validation process helps to catch errors early in the training process. If the TOML file is not valid, an error message will be displayed, indicating the missing or incorrect fields. + +Using a Pretrained Model +============================ +.. warning:: This feature is currently a work in progress. + + +All training runs performed with *modelforge* are logged in a `wandb` project. You can access this project via the following link : https://https://wandb.ai/modelforge_nnps/projects/latest. Using the wandb API, you can download the trained models and use them for inference tasks. + + +Investigating Hyperparameter Impact on Model Performance +======================================================== + +For each supported architecture, modelforge provides reasonable priors for hyperparameters. Hyperparameter optimization is conducted using `Ray`. + +.. autoclass:: modelforge.train.tuning.RayTuner + :noindex: + + +*Modelforge* offers the :py:class:`~modelforge.train.tuning.RayTuner` class, which facilitates the exploration of hyperparameter impacts on model performance given specific training and dataset parameters within a defined computational budget. + + + + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: \ No newline at end of file diff --git a/docs/image/overview_core_network.png b/docs/image/overview_core_network.png new file mode 100644 index 00000000..3c8ba26f Binary files /dev/null and b/docs/image/overview_core_network.png differ diff --git a/docs/image/overview_core_network.svg b/docs/image/overview_core_network.svg new file mode 100644 index 00000000..1dfefbdb --- /dev/null +++ b/docs/image/overview_core_network.svg @@ -0,0 +1,708 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/image/overview_network.png b/docs/image/overview_network.png new file mode 100644 index 00000000..e788d408 Binary files /dev/null and b/docs/image/overview_network.png differ diff --git a/docs/image/overview_network.svg b/docs/image/overview_network.svg new file mode 100644 index 00000000..04b36c6f --- /dev/null +++ b/docs/image/overview_network.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 14b0ba64..3593aca3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -6,11 +6,21 @@ Welcome to modelforge's documentation! ========================================================= +*Modelforge* is a Python package to build and train machine learned interatomic potentials using PyTorch and Lightning. It is designed to be modular and flexible, allowing for easy extension and customization. It provides access to popular datasets and models, making it trivial to get started with training and evaluation. + +The best way to get started is to read the :doc:`getting_started` guide, which oultines how to + .. toctree:: :maxdepth: 2 :caption: Contents: getting_started + potentials + datasets + training + inference + for_developer + tuning api @@ -20,4 +30,4 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` -* :ref:`search` +* :ref:`search` \ No newline at end of file diff --git a/docs/inference.rst b/docs/inference.rst new file mode 100644 index 00000000..4009bb92 --- /dev/null +++ b/docs/inference.rst @@ -0,0 +1,46 @@ +Inference Mode +########################## + +Inference mode is a mode allows us to use the trained model to make predictions. Given that a key usage of the inference mode will be molecule simulation, more efficient schemes for calculating interacting pairs are needed. + +Neighborlists +------------------------------------------ +Currently, there are two neighborlist strategies implemented within modelforge for inference, the brute force neighbolist (:class:`~modelforge.potential.neighbors.NeighborlistBruteNsq`) and Verlet neighborlist (:class:`~modelforge.potential.NeighborlistVerletNsq`). Both neighborlists support both periodic and not periodic orthogonal boxes. + +The neighborlist strategy can be toggled during potential setup via the `inference_neighborlist_strategy` parameter passed to the :class:`~modelforge.potential.models.NeuralNetworkPotentialFactory`. The default is the Verlet neighborlist ("verlet"); brute can be set via "brute". + +Brute force neighborlist +^^^^^^^^^^^^^^^^^^^^^^^^ +The brute force neighborlist calculates the pairs within the interaction cutoff by considering all possible pairs each time called, via an order N^2 operation. Typically this approach should only be used for very system sizes, given the scaling; furthermore the N^2 approach used to generate this list utilizes a large amount of memory as the system size grows. + + + +Verlet neighborlist +^^^^^^^^^^^^^^^^^^^^^^^^ + +The Verlet neighborlist operates under the assumption that under short time windows, the local environment around a given particle does not change significantly. As such, information about this local environment can be reused between subsequent steps, eliminating the need for a costly build step. + +To do this, the local environment of a given particle is identified and saved in a list (e.g., we can call this the verlet list), using the criteria pair distance < cutoff + skin. The skin is a user modifiable distance that captures a region of space beyond the interaction cutoff. In the current implementation, this verlet list is generated using the same order N^2 approach as the brute for scheme. Again, because positions are correlated with time, we typically can avoid performing another order N^2 calculation for several timesteps. Steps in between rebuilds scale as order N*M, where M is the average number of neighbors (which is typically much less than N). In our implementation, the verlet list is automatically regenerated when any given particle moves more than skin/2 (since the last build), to ensure that interactions are not missed. + +Larger values of skin result in longer time periods between rebuilds, but also typically increase the number of calculations that need to be perform at each timestep (as M will typically be larger). As such, this value can have a significant impact on performance of this calculation. + +Note: Since this utilizes an N^2 computation within Torch, the memory footprint may be problematic as system size grows. A cell list based approach will be implemented in the future. + + +Load inference potential from training checkpoint +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To use the trained model for inference, the checkpoint file generated during training must be loaded. The checkpoint file contains the model's weights, optimizer state, and other training-related information. The `load_inference_model_from_checkpoint` function provides a convenient way to load the checkpoint file and generate an inference model. + +.. code-block:: python + + from modelforge.potential.models import load_inference_model_from_checkpoint + + inference_model = load_inference_model_from_checkpoint(checkpoint_file) + + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + diff --git a/docs/potentials.rst b/docs/potentials.rst new file mode 100644 index 00000000..de24eee3 --- /dev/null +++ b/docs/potentials.rst @@ -0,0 +1,171 @@ +Potentials +=============== + +Potentials: Overview +------------------------ + +A potential in *modelforge* encapsulates all the operations required to map a +description of a molecular system (which includes the Cartesian coordinates, +atomic numbers, and total charge of a system (optionally also the spin state)) +to, at a minimum, its energy. Specifically, a potential takes as input a +:py:class:`~modelforge.dataset.dataset.NNPInput` dataclass and outputs a +dictionary of PyTorch tensors representing per-atom and/or per-molecule +properties, including per-molecule energies. A potential comprises three main +components: + +1. **Input Preparation Module** + (:class:`~modelforge.potential.model.InputPreparation`): Responsible for + generating the pair list, pair distances, and pair displacement vectors based + on atomic coordinates and the specified cutoff. This module processes the raw + input data into a format suitable for the neural network. + +2. **Core Model** (:class:`~modelforge.potential.model.CoreNetwork`): The neural + network containing the learnable parameters, which forms the core of the + potential. This module generates per-atom scalar and, optionally, tensor properties. The inputs to the core model + include atom pair indices, pair distances, pair displacement vectors, and atomic properties such as atomic numbers, charges, and spin state. + +3. **Postprocessing Module** + (:class:`~modelforge.potential.model.PostProcessing`): Contains operations + applied to per-atom properties as well as reduction operations to obtain + per-molecule properties. Examples include atomic energy scaling and summation of per-atom energies to obtain per-molecule energies for reduction operations. + +A specific neural network (e.g., PhysNet) implements the core model, while the +input preparation and postprocessing modules are independent of the neural +network architecture. + + +Implemented Models +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +*Modelforge* currently supports the following potentials: + +- Invariant architectures: + * SchNet + * ANI2x +- Equivariant architectures: + * PaiNN + * PhysNet + * TensorNet + * SAKE + +Additionally, the following models are currently under development and can be expected in the near future: + +- SpookyNet +- DimeNet +- AimNet2 + +Each potential currently implements the total energy prediction with per-atom +forces within a given cutoff radius. The models can be trained on energies and +forces. PaiNN and PhysNet can also predict partial charges and +calculate long-range interactions using reaction fields or Coulomb potential. +PaiNN can additionally use multipole expansions. + +Using TOML files to configure potentials +-------------------------------------------- + +To initialize a potential, a potential factory is used: +:class:`~modelforge.potential.potential.NeuralNetworkPotentialFactory`. +This takes care of initialization of the potential and the input preparation and postprocessing modules. + +A neural network potential is defined by a configuration file in the TOML format. +This configuration file includes parameters for the neural network, as well as +for the input preparation and postprocessing modules. Below is an example +configuration file for the PhysNet model: + +.. literalinclude:: ../modelforge/tests/data/potential_defaults/physnet.toml + :language: toml + :caption: PhysNet Configuration + +There are two main sections in the configuration file: `core_parameter` and +`postprocessing_parameter`. The `core_parameter` section contains the parameters +for the neural network, while the `postprocessing_parameter` section contains +the parameters for the postprocessing operations. Explanation of fields in +`physnet.toml`: + +* `potential_name`: Specifies the type of potential to use, in this case, PhysNet. +* `number_of_radial_basis_functions`: Number of radial basis functions. +* `maximum_interaction_radius`: Cutoff radius for considering neighboring atoms. +* `number_of_interaction_residual`: PhysNet hyperparamter defining the depth of the network. +* `number_of_modules`: PhysNet hyperparamter defining the depth of the network;which scales with (number_of_interaction_residual * number_of_modules). +* `activation_function_name`: Activation function used in the neural network. +* `properties_to_featurize`: List of properties to featurize. +* `maximum_atomic_number```: Maximum atomic number in the dataset. +* `number_of_per_atom_features`: Number of features for each atom used for the embedding. This is the number of features that are used to represent each atom in the neural network. +* `normalize`: Whether to normalize energies for training. If this is set to true the mean and standard deviation of the energies are calculated and used to normalize the energies. +* `from_atom_to_system_reduction`: Whether to reduce the per-atom properties to per-molecule properties. +* `keep_per_atom_property`: If this is set to true the per-atom energies are returned as well. + + +Default parameter files for each potential are available in `modelforge/tests/data/potential_defaults`. These files can be used as starting points for creating new potential configuration files. + +.. note:: All parameters in the configuration files have units attached where applicable. Units within modelforge a represented using the `openff.units` package (https://docs.openforcefield.org/projects/units/en/stable/index.html), which is a wrapper around the `pint` package. Definition of units within the TOML files must unit names available in the `openff.units` package (https://github.com/openforcefield/openff-units/blob/main/openff/units/data/defaults.txt). + +Use cases of the factory class +-------------------------------------------- + +There are three main use cases of the :class:`~modelforge.potential.potential.NeuralNetworkPotentialFactory`: + +1. Create and train a model, then save the state_dict of its potential. Load the state_dict to the potential of an existing trainer (with defined hyperparameters) to resume training. + +2. Load a potential for inference from a saved state_dict. + +3. Load an inference from a checkpoint file. + +.. note:: The general idea to handle these use cases is that always call `generate_trainer()` to create or load a trainer; use `generate_potential()` for loading inference potential (this is also how `load_inference_model_from_checkpoint()` is implemented). + +.. code-block:: python + :linenos: + + # Use case 1 + trainer = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config["potential"], + training_parameter=config["training"], + runtime_parameter=config["runtime"], + dataset_parameter=config["dataset"], + ) + torch.save(trainer.lightning_module.state_dict(), file_path) + trainer2 = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config2["potential"], + training_parameter=config2["training"], + runtime_parameter=config2["runtime"], + dataset_parameter=config2["dataset"], + ) + trainer2.lightning_module.load_state_dict(torch.load(file_path)) + + # Use case 2 + potential = NeuralNetworkPotentialFactory.generate_potential( + simulation_environment="PyTorch", + potential_parameter=config["potential"], + ) + potential.load_state_dict(torch.load(file_path)) + + # Use case 3 + from modelforge.potential.potential import load_inference_model_from_checkpoint + potential = load_inference_model_from_checkpoint(ckpt_file) + +Example +------------------------------------ + +Below is an example of how to create a SchNet model using the potential factory, although we note these operations do not typically need to be performed directly by a user and are handled by routines available in the training module: + +.. code-block:: python + :linenos: + + model_name = "SchNet" + + # reading default parameters + from modelforge.tests.data import potential_defaults + from importlib import + import toml + + filename = ( + resources.files(potential_defaults) / f"{model_name.lower()}_defaults.toml" + ) + potential_parameters = toml.load(filename) + + # initialize the models with the given parameter set + model = NeuralNetworkPotentialFactory.generate_potential( + use="inference", + model_type=model_name, + model_parameter=potential_parameters, + ) \ No newline at end of file diff --git a/docs/training.rst b/docs/training.rst new file mode 100644 index 00000000..dbb756d6 --- /dev/null +++ b/docs/training.rst @@ -0,0 +1,122 @@ +Training +======== + +Training: Overview +------------------------------------------ + +During training the parameters of a potential are fitted to reproduce the target properties provided by a dataset. These are typically energies and forces, but can also be other properties provided by a dataset (e.g., dipole moments). + +The properties a given potential can be trained on are determined by the potential itself and by the loss function used to train the potential. Each of the models implemented in *modelforge* have flexible numbers of output heads, each of which can be fitted against a different scalar/tensor property. + +*Modelforge* uses Pytorch Lightning to train models. The training process is controlled by a :class:`~modelforge.train.training.PotentialTrainer` object, which is responsible for managing the training loop, the optimizer, the learning rate scheduler, and the early stopping criteria. The training process is controlled by a configuration file, `training.toml`, which specifies the number of epochs, the learning rate, the loss function, and the splitting strategy for the dataset. The training process can be started by + + +Training Configuration +------------------------------------------ + +The training process is controlled by a configuration file, typically written in TOML format. This file includes various sections for configuring different aspects of the training, such as logging, learning rate scheduling, loss functions, dataset splitting strategies, and early stopping criteria. + + +.. literalinclude:: ../modelforge/tests/data/training_defaults/default.toml + :language: toml + :caption: Training Configuration + +The TOML files are split into four categories: `potential`, `training`, +`runtime` and `dataset`. While it is possible to read in a single TOML file +defining fields for all three categories, often it is useful to read them in +separately (a common use case where this is useful is when a potential is trained on +different datasets — instead of repeating the `training` and `potential` +sections in each TOML file, only the `dataset.toml` file needs to be changed). + +Learning rate scheduler +^^^^^^^^^^^^^^^^^^^^^^^^ + +The learning rate scheduler is responsible for adjusting the learning rate during training. *Modelforge* uses the `REduceLROnPlateau scheduler`, which reduces the learning rate by a factor when the RMSE of the energy prediction on the validation set does not improve for a given number of epochs. The scheduler is controlled by the parameters in the `[training.lr_scheduler]` section of the `training.toml` file. + +Loss function +^^^^^^^^^^^^^^^^^^^^^^^^ +The loss function quantifies the discrepancy between the model's predictions and the target properties, providing a scalar value that guides the optimizer in updating the model's parameters. This function is configured in the `[training.loss]` section of the training TOML file. + +Depending on the specified `loss_components`` section, the loss function can combine various individual loss functions. *Modelforge* always includes the mean squared error (MSE) for energy prediction, and may also incorporate MSE for force prediction, dipole moment prediction, and partial charge prediction. + +The design of the loss function is intrinsically linked to the structure of the energy function. For instance, if the energy function aggregates atomic energies, then loss_components should include `per_system_energy` and optionally, `per_atom_force`. + + +Predicting Short-Range Atomic Energies +************************************************************ + +If the total atomic energy is calculated with a short cutoff radius, we can directly match the sum of the atomic energies `E_i` (as predicted by the potential model architecture) to the total energy of the molecule (provided by the dataset). + +In that case the total energy is calculated as + +.. math:: E = \sum_i^N E_i + +The loss function can then be formulated as: + +.. math:: L(E) = (E - E^{pred})^2 + +Alternatively, if the dataset includes per atom forces, the loss function can be expressed as: + +.. math:: L(E,F) = w_E * (E - E^{pred})^2 + w_F \frac{1}{3N} \sum_i^N \sum_j^3 (F_{ij} - F_{ij}^{pred})^2 + +where `w_E` and `w_F` are the weights for the energy and force components, respectively. + +Predicting Short-Range Atomic Energy with Long-Range Interactions +************************************************************************* + +.. warning:: + The following section is under development and may not be fully implemented in the current version of *modelforge*. + +In scenarios where long-range interactions are considered, additional terms are incorporated into the loss function. There are two long range interactions that are of interest: long range dispersion interactions and electrostatics. + +To calculate long range electrostatics the first moment of the charge density (partial charges) is predicted by the machine learning potential. The atomic +charges are then used to calculate the long range electrostatics. The expression +for the total energy is then: + +.. math:: E = \sum_i^N E_i + k_c \sum_i^N \sum_{j>i}^N \frac{q_i q_j}{r_{ij}} + +where `k_c` is the Coulomb constant, `q_i` and `q_j` are the atomic charges, and the loss function is: + +.. math:: L(E,F,Q) = L(E,F) + w_Q (\sum_i^N q_i - Q_i^{pred})^2 + \frac{w_p}{3} \sum_j^3 (\sum_i^N q_i r_i,j - p_j^{ref})^2 + +where `w_Q` is the weight for the charge component, `w_p` the weight for the dipole moment component, and `p_j^{ref}` is the reference dipole moment. + + +Splitting Strategies +^^^^^^^^^^^^^^^^^^^^^^^^ + +The dataset splitting strategy is crucial for ensuring that a potential generalizes well to unseen data. The recommended approach in *modelforge* is to randomly split the dataset into 80% training, 10% validation, and 10% test sets, based on molecules rather than individual conformations. This ensures that different conformations of the same molecule are consistently assigned to the same split, preventing data leakage and ensuring robust model evaluation. + + +*Modelforge* also provides other splitting strategies, including: + +- :class:`~modelforge.dataset.utils.FirstComeFirstServeStrategy`: Splits the dataset based on the order of records (molecules). +- :class:`~modelforge.dataset.utils.RandomSplittingStrategy`: Splits the dataset randomly based on conformations. + +To use a different data split ratio, you can specify a custom split list in the +splitting strategy. The most effective way to pass this information to the +training process is by defining the appropriate fields in the TOML file providing the training parameters, see the TOML file above. + + +Train a Model +---------------------- + +The best way to get started is to train a model. We provide a script to +train one of the implemented models on a dataset using default configurations. +The script, along with a default configuration TOML file, can be found in the `scripts`` directory. The TOML file provides parameters to train a `TensorNet` potential on the `QM9` dataset on a single GPU. + +The recommended method for training models is through the `perform_training.py`` script. This script automates the training process by reading the TOML configuration files, which define the potential (model), dataset, training routine, and runtime environment. The script then initializes a Trainer class from PyTorch Lightning, which handles the training loop, gradient updates, and evaluation metrics. During training, the model is optimized on the specified dataset according to the defined potential and training routine. After training completes, the script outputs performance metrics on the validation and test sets, providing insights into the model's accuracy and generalization. + +Additionally, the script saves the trained model with checkpoint files in the directory specified by the `[save_dir]`` field in the runtime section of the TOML file, enabling you to resume training or use the potential model for inference later. + +To initiate the training process, execute the following command in your terminal (inside the `scripts` directory): + +.. code-block:: bash + + python perform_training.py + --condensed_config_path="config.toml" + --accelerator="gpu" + --device=1 + +This command specifies the path to the configuration file, selects GPU acceleration, and designates the device index to be used during training. Adjust these parameters as necessary to fit your computational setup and training needs. + diff --git a/docs/tuning.rst b/docs/tuning.rst new file mode 100644 index 00000000..f4a58fd4 --- /dev/null +++ b/docs/tuning.rst @@ -0,0 +1,7 @@ +tuning module +============= + +.. automodule:: modelforge.train.tuning + :members: + :undoc-members: + :show-inheritance: diff --git a/modelforge/curation/ani1x_curation.py b/modelforge/curation/ani1x_curation.py index edab841a..382825ec 100644 --- a/modelforge/curation/ani1x_curation.py +++ b/modelforge/curation/ani1x_curation.py @@ -83,6 +83,8 @@ def _init_dataset_parameters(self): self.dataset_md5_checksum = data_inputs[self.version_select][ "dataset_md5_checksum" ] + self.dataset_length = data_inputs[self.version_select]["dataset_length"] + self.dataset_filename = data_inputs[self.version_select]["dataset_filename"] logger.debug( f"Dataset: {self.version_select} version: {data_inputs[self.version_select]['version']}" ) @@ -407,26 +409,27 @@ def process( "max_records and total_conformers cannot be set at the same time." ) - from modelforge.utils.remote import download_from_figshare + from modelforge.utils.remote import download_from_url url = self.dataset_download_url # download the dataset - self.name = download_from_figshare( + download_from_url( url=url, md5_checksum=self.dataset_md5_checksum, output_path=self.local_cache_dir, + output_filename=self.dataset_filename, + length=self.dataset_length, force_download=force_download, ) self._clear_data() # process the rest of the dataset - if self.name is None: - raise Exception("Failed to retrieve name of file from figshare.") + self._process_downloaded( self.local_cache_dir, - self.name, + self.dataset_filename, max_records=max_records, max_conformers_per_record=max_conformers_per_record, total_conformers=total_conformers, diff --git a/modelforge/curation/ani2x_curation.py b/modelforge/curation/ani2x_curation.py index ccddf0ea..09ecf9e9 100644 --- a/modelforge/curation/ani2x_curation.py +++ b/modelforge/curation/ani2x_curation.py @@ -69,6 +69,9 @@ def _init_dataset_parameters(self) -> None: self.dataset_md5_checksum = data_inputs[self.version_select][ "dataset_md5_checksum" ] + self.dataset_filename = data_inputs[self.version_select]["dataset_filename"] + self.dataset_length = data_inputs[self.version_select]["dataset_length"] + logger.debug( f"Dataset: {self.version_select} version: {data_inputs[self.version_select]['version']}" ) @@ -291,21 +294,20 @@ def process( "max_records and total_conformers cannot be set at the same time." ) - from modelforge.utils.remote import download_from_zenodo + from modelforge.utils.remote import download_from_url url = self.dataset_download_url # download the dataset - self.name = download_from_zenodo( + download_from_url( url=url, md5_checksum=self.dataset_md5_checksum, output_path=self.local_cache_dir, + output_filename=self.dataset_filename, + length=self.dataset_length, force_download=force_download, ) - if self.name is None: - raise Exception("Failed to retrieve name of file from Zenodo.") - # clear any data that might be present so we don't append to it self._clear_data() @@ -314,13 +316,13 @@ def process( extract_tarred_file( input_path_dir=self.local_cache_dir, - file_name=self.name, + file_name=self.dataset_filename, output_path_dir=self.local_cache_dir, mode="r:gz", ) # the untarred file will be in a directory named 'final_h5' within the local_cache_dir, - hdf5_filename = f"{self.name.replace('.tar.gz', '')}.h5" + hdf5_filename = f"{self.dataset_filename.replace('.tar.gz', '')}.h5" # process the rest of the dataset self._process_downloaded( diff --git a/modelforge/curation/phalkethoh_curation.py b/modelforge/curation/phalkethoh_curation.py index 600ef8ca..8355ff48 100644 --- a/modelforge/curation/phalkethoh_curation.py +++ b/modelforge/curation/phalkethoh_curation.py @@ -111,7 +111,7 @@ def _init_record_entries_series(self): "name": "single_rec", "dataset_name": "single_rec", "source": "single_rec", - "total_charge": "single_rec", + "total_charge": "single_atom", "atomic_numbers": "single_atom", "n_configs": "single_rec", "molecular_formula": "single_rec", @@ -267,7 +267,7 @@ def _calculate_total_charge( rdmol = Chem.MolFromSmiles(smiles, sanitize=False) total_charge = sum(atom.GetFormalCharge() for atom in rdmol.GetAtoms()) - return (int(total_charge) * unit.elementary_charge,) + return int(total_charge) * unit.elementary_charge def _process_downloaded( self, @@ -277,6 +277,8 @@ def _process_downloaded( max_conformers_per_record: Optional[int] = None, total_conformers: Optional[int] = None, atomic_numbers_to_limit: Optional[List[int]] = None, + max_force: Optional[unit.Quantity] = None, + final_conformer_only: Optional[bool] = None, ): """ Processes a downloaded dataset: extracts relevant information. @@ -295,6 +297,11 @@ def _process_downloaded( If set, this will limit the total number of conformers to the specified number. atomic_numbers_to_limit: Optional[List[int]], optional, default=None If set, this will limit the dataset to only include molecules with atomic numbers in the list. + max_force: Optional[float], optional, default=None + If set, this will exclude any conformers with a force that exceeds this value. + final_conformer_only: Optional[bool], optional, default=None + If set to True, only the final conformer of each record will be processed. This should be the final + energy minimized conformer. """ from tqdm import tqdm import numpy as np @@ -358,7 +365,7 @@ def _process_downloaded( ] data_temp["n_configs"] = 0 - (data_temp["total_charge"],) = self._calculate_total_charge( + data_temp["total_charge"] = self._calculate_total_charge( data_temp[ "canonical_isomeric_explicit_hydrogen_mapped_smiles" ] @@ -377,105 +384,123 @@ def _process_downloaded( name = key index = self.molecule_names[name] + if final_conformer_only: + trajectory = [trajectory[-1]] for state in trajectory: + add_record = True properties, config = state - self.data[index]["n_configs"] += 1 - - # note, we will use the convention of names being lowercase - # and spaces denoted by underscore - quantity = "geometry" - quantity_o = "geometry" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = config.reshape(1, -1, 3) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - config.reshape(1, -1, 3), + + # if set, let us see if the configuration has a force that exceeds the maximum + if max_force is not None: + force_magnitude = ( + np.abs( + properties["properties"]["current gradient"] + + properties["properties"][ + "dispersion correction gradient" + ] ) + * self.qm_parameters["dft_total_force"]["u_in"] ) + if np.any(force_magnitude > max_force): + add_record = False + if add_record: + self.data[index]["n_configs"] += 1 + + # note, we will use the convention of names being lowercase + # and spaces denoted by underscore + quantity = "geometry" + quantity_o = "geometry" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = config.reshape(1, -1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + config.reshape(1, -1, 3), + ) + ) - # note, we will use the convention of names being lowercase - # and spaces denoted by underscore - quantity = "current energy" - quantity_o = "dft_total_energy" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = properties["properties"][ - quantity - ] - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - properties["properties"][quantity], + # note, we will use the convention of names being lowercase + # and spaces denoted by underscore + quantity = "current energy" + quantity_o = "dft_total_energy" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = properties["properties"][ + quantity + ] + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + properties["properties"][quantity], + ) ) - ) - quantity = "dispersion correction energy" - quantity_o = "dispersion_correction_energy" - # Note need to typecast here because of a bug in the - # qcarchive entry: see issue: https://github.com/MolSSI/QCFractal/issues/766 - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - float(properties["properties"][quantity]) - ).reshape(1, 1) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array( - float(properties["properties"][quantity]) - ).reshape(1, 1), - ), - ) + quantity = "dispersion correction energy" + quantity_o = "dispersion_correction_energy" + # Note need to typecast here because of a bug in the + # qcarchive entry: see issue: https://github.com/MolSSI/QCFractal/issues/766 + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + float(properties["properties"][quantity]) + ).reshape(1, 1) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array( + float(properties["properties"][quantity]) + ).reshape(1, 1), + ), + ) - quantity = "current gradient" - quantity_o = "dft_total_gradient" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - properties["properties"][quantity] - ).reshape(1, -1, 3) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array( - properties["properties"][quantity] - ).reshape(1, -1, 3), + quantity = "current gradient" + quantity_o = "dft_total_gradient" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + properties["properties"][quantity] + ).reshape(1, -1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array( + properties["properties"][quantity] + ).reshape(1, -1, 3), + ) ) - ) - quantity = "dispersion correction gradient" - quantity_o = "dispersion_correction_gradient" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - properties["properties"][quantity] - ).reshape(1, -1, 3) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array( - properties["properties"][quantity] - ).reshape(1, -1, 3), + quantity = "dispersion correction gradient" + quantity_o = "dispersion_correction_gradient" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + properties["properties"][quantity] + ).reshape(1, -1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array( + properties["properties"][quantity] + ).reshape(1, -1, 3), + ) ) - ) - quantity = "scf dipole" - quantity_o = "scf_dipole" - if quantity_o not in self.data[index].keys(): - self.data[index][quantity_o] = np.array( - properties["properties"][quantity] - ).reshape(1, 3) - else: - self.data[index][quantity_o] = np.vstack( - ( - self.data[index][quantity_o], - np.array( - properties["properties"][quantity] - ).reshape(1, 3), + quantity = "scf dipole" + quantity_o = "scf_dipole" + if quantity_o not in self.data[index].keys(): + self.data[index][quantity_o] = np.array( + properties["properties"][quantity] + ).reshape(1, 3) + else: + self.data[index][quantity_o] = np.vstack( + ( + self.data[index][quantity_o], + np.array( + properties["properties"][quantity] + ).reshape(1, 3), + ) ) - ) # assign units for datapoint in self.data: @@ -564,6 +589,8 @@ def process( max_conformers_per_record: Optional[int] = None, total_conformers: Optional[int] = None, limit_atomic_species: Optional[list] = None, + max_force: Optional[unit.Quantity] = None, + final_conformer_only=None, n_threads=2, ) -> None: """ @@ -586,7 +613,11 @@ def process( Note defining this will only fetch from the "SPICE PubChem Set 1 Single Points Dataset v1.2" limit_atomic_species: Optional[list] = None, If set to a list of element symbols, records that contain any elements not in this list will be ignored. - n_threads, int, default=6 + max_force: Optional[float], optional, default=None + If set this any confirugrations with a force that exceeds this value will be excluded. + final_conformer_only: Optional[bool], optional, default=None + If set to True, only the final conformer of each record will be processed. + n_threads, int, default=2 Number of concurrent threads for retrieving data from QCArchive Examples -------- @@ -664,6 +695,8 @@ def process( max_conformers_per_record=max_conformers_per_record, total_conformers=total_conformers, atomic_numbers_to_limit=self.atomic_numbers_to_limit, + max_force=max_force, + final_conformer_only=final_conformer_only, ) self._generate_hdf5() diff --git a/modelforge/curation/qm9_curation.py b/modelforge/curation/qm9_curation.py index 847eaa01..67adde87 100644 --- a/modelforge/curation/qm9_curation.py +++ b/modelforge/curation/qm9_curation.py @@ -68,6 +68,9 @@ def _init_dataset_parameters(self) -> None: self.dataset_md5_checksum = data_inputs[self.version_select][ "dataset_md5_checksum" ] + self.dataset_filename = data_inputs[self.version_select]["dataset_filename"] + self.dataset_length = data_inputs[self.version_select]["dataset_length"] + logger.debug( f"Dataset: {self.version_select} version: {data_inputs[self.version_select]['version']}" ) @@ -640,24 +643,22 @@ def process( "max_records and total_conformers cannot be set at the same time." ) - from modelforge.utils.remote import download_from_figshare + from modelforge.utils.remote import download_from_url url = self.dataset_download_url # download the dataset - self.name = download_from_figshare( + download_from_url( url=url, md5_checksum=self.dataset_md5_checksum, output_path=self.local_cache_dir, + output_filename=self.dataset_filename, + length=self.dataset_length, force_download=force_download, ) # clear out the data array before we process self._clear_data() - # process the rest of the dataset - if self.name is None: - raise Exception("Failed to retrieve name of file from figshare.") - # untar the dataset from modelforge.utils.misc import extract_tarred_file @@ -665,7 +666,7 @@ def process( # creating a directory called qm9_xyz_files to hold the contents extract_tarred_file( input_path_dir=self.local_cache_dir, - file_name=self.name, + file_name=self.dataset_filename, output_path_dir=f"{self.local_cache_dir}/qm9_xyz_files", mode="r:bz2", ) diff --git a/modelforge/curation/scripts/curate_PhAlkEthOH.py b/modelforge/curation/scripts/curate_PhAlkEthOH.py index 6dfab740..cecff1ff 100644 --- a/modelforge/curation/scripts/curate_PhAlkEthOH.py +++ b/modelforge/curation/scripts/curate_PhAlkEthOH.py @@ -20,6 +20,8 @@ def PhAlkEthOH_openff_wrapper( max_conformers_per_record=None, total_conformers=None, limit_atomic_species=None, + max_force=None, + final_conformer_only=False, ): """ This curates and processes the SPICE 114 dataset at the OpenFF level of theory into an hdf5 file. @@ -49,7 +51,10 @@ def PhAlkEthOH_openff_wrapper( limit_atomic_species: list, optional, default=None A list of atomic species to limit the dataset to. Any molecules that contain elements outside of this list will be ignored. If not defined, no filtering by atomic species will be performed. - + max_force: float, optional, default=None + The maximum force to allow in the dataset. Any conformers with forces greater than this value will be ignored. + final_conformer_only: bool, optional, default=False + If True, only the final conformer for each molecule will be processed. If False, all conformers will be processed. """ from modelforge.curation.phalkethoh_curation import PhAlkEthOHCuration @@ -67,6 +72,8 @@ def PhAlkEthOH_openff_wrapper( total_conformers=total_conformers, limit_atomic_species=limit_atomic_species, n_threads=1, + max_force=max_force, + final_conformer_only=final_conformer_only, ) print(f"Total records: {PhAlkEthOH_dataset.total_records}") print(f"Total conformers: {PhAlkEthOH_dataset.total_conformers}") @@ -74,6 +81,8 @@ def PhAlkEthOH_openff_wrapper( def main(): + from openff.units import unit + # define the location where to store and output the files import os @@ -83,9 +92,9 @@ def main(): # We'll want to provide some simple means of versioning # if we make updates to either the underlying dataset, curation modules, or parameters given to the code - version = "0" + version = "1" # version of the dataset to curate - version_select = f"v_{version}" + version_select = f"v_0" # curate dataset with 1000 total conformers, max of 10 conformers per record hdf5_file_name = f"PhAlkEthOH_openff_dataset_v{version}_ntc_1000.hdf5" @@ -99,6 +108,7 @@ def main(): max_records=1000, total_conformers=1000, max_conformers_per_record=10, + max_force=1.0 * unit.hartree / unit.bohr, ) # curate the full dataset @@ -110,6 +120,36 @@ def main(): local_cache_dir, force_download=False, version_select=version_select, + max_force=1.0 * unit.hartree / unit.bohr, + ) + + # curate dataset with 1000 total conformers, max of 10 conformers per record + hdf5_file_name = f"PhAlkEthOH_openff_dataset_v{version}_ntc_1000_minimal.hdf5" + + PhAlkEthOH_openff_wrapper( + hdf5_file_name, + output_file_dir, + local_cache_dir, + force_download=False, + version_select=version_select, + max_records=1000, + total_conformers=1000, + max_conformers_per_record=10, + max_force=1.0 * unit.hartree / unit.bohr, + final_conformer_only=True, + ) + + # curate the full dataset + hdf5_file_name = f"PhAlkEthOH_openff_dataset_v{version}_minimal.hdf5" + print("total dataset") + PhAlkEthOH_openff_wrapper( + hdf5_file_name, + output_file_dir, + local_cache_dir, + force_download=False, + version_select=version_select, + max_force=1.0 * unit.hartree / unit.bohr, + final_conformer_only=True, ) diff --git a/modelforge/curation/scripts/curate_spice114.py b/modelforge/curation/scripts/curate_spice114.py index aa68d7d2..4a43869d 100644 --- a/modelforge/curation/scripts/curate_spice114.py +++ b/modelforge/curation/scripts/curate_spice114.py @@ -68,7 +68,7 @@ def spice114_wrapper( """ from modelforge.curation.spice_1_curation import SPICE1Curation - spice_114 = SPICE114Curation( + spice_114 = SPICE1Curation( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, @@ -95,9 +95,9 @@ def main(): # We'll want to provide some simple means of versioning # if we make updates to either the underlying dataset, curation modules, or parameters given to the code - version = "0" + version = "1" # version of the dataset to curate - version_select = f"v_{version}" + version_select = f"v_0" # version v_0 corresponds to SPICE 1.1.4 release # curate ANI1x test dataset with 1000 total conformers, max of 10 conformers per record diff --git a/modelforge/curation/scripts/curate_spice114_openff.py b/modelforge/curation/scripts/curate_spice114_openff.py index cb589a83..d3f2cd81 100644 --- a/modelforge/curation/scripts/curate_spice114_openff.py +++ b/modelforge/curation/scripts/curate_spice114_openff.py @@ -69,9 +69,9 @@ def spice_114_openff_wrapper( """ - from modelforge.curation.spice_1_openff_curation import SPICEOpenFFCuration + from modelforge.curation.spice_1_openff_curation import SPICE1OpenFFCuration - spice_dataset = SPICEOpenFFCuration( + spice_dataset = SPICE1OpenFFCuration( hdf5_file_name=hdf5_file_name, output_file_dir=output_file_dir, local_cache_dir=local_cache_dir, @@ -99,9 +99,9 @@ def main(): # We'll want to provide some simple means of versioning # if we make updates to either the underlying dataset, curation modules, or parameters given to the code - version = "0" + version = "1" # version of the dataset to curate - version_select = f"v_{version}" + version_select = f"v_0" ani2x_elements = ["H", "C", "N", "O", "F", "Cl", "S"] diff --git a/modelforge/curation/scripts/curate_spice2.py b/modelforge/curation/scripts/curate_spice2.py index 1cd98985..8d9f72a6 100644 --- a/modelforge/curation/scripts/curate_spice2.py +++ b/modelforge/curation/scripts/curate_spice2.py @@ -94,9 +94,9 @@ def main(): # We'll want to provide some simple means of versioning # if we make updates to either the underlying dataset, curation modules, or parameters given to the code - version = "0" + version = "1" # version of the dataset to curate - version_select = f"v_{version}" + version_select = f"v_0" # version v_0 corresponds to SPICE 2.0.1 diff --git a/modelforge/curation/spice_1_curation.py b/modelforge/curation/spice_1_curation.py index 62ada82e..df31c211 100644 --- a/modelforge/curation/spice_1_curation.py +++ b/modelforge/curation/spice_1_curation.py @@ -2,6 +2,7 @@ from typing import Optional from loguru import logger from openff.units import unit +import numpy as np class SPICE1Curation(DatasetCuration): @@ -68,6 +69,9 @@ def _init_dataset_parameters(self): self.dataset_md5_checksum = data_inputs[self.version_select][ "dataset_md5_checksum" ] + self.dataset_filename = data_inputs[self.version_select]["dataset_filename"] + self.dataset_length = data_inputs[self.version_select]["dataset_length"] + logger.debug( f"Dataset: {self.version_select} version: {data_inputs[self.version_select]['version']}" ) @@ -160,7 +164,7 @@ def _init_record_entries_series(self): "n_configs": "single_rec", "smiles": "single_rec", "subset": "single_rec", - "total_charge": "single_rec", + "total_charge": "series_mol", "geometry": "series_atom", "dft_total_energy": "series_mol", "dft_total_gradient": "series_atom", @@ -196,7 +200,7 @@ def _calculate_reference_charge(self, smiles: str) -> unit.Quantity: rdmol = Chem.MolFromSmiles(smiles, sanitize=False) total_charge = sum(atom.GetFormalCharge() for atom in rdmol.GetAtoms()) - return int(total_charge) * unit.elementary_charge + return np.array([int(total_charge)]) * unit.elementary_charge def _process_downloaded( self, @@ -304,9 +308,17 @@ def _process_downloaded( ds_temp[param_out] = temp * param_unit else: ds_temp[param_out] = temp - ds_temp["total_charge"] = self._calculate_reference_charge( - ds_temp["smiles"] + total_charge = self._calculate_reference_charge(ds_temp["smiles"]) + + total_charge_reshaped = ( + np.repeat(total_charge.m, conformers_per_record).reshape( + conformers_per_record, -1 + ) + * total_charge.u ) + + ds_temp["total_charge"] = total_charge_reshaped + ds_temp["dft_total_force"] = -ds_temp["dft_total_gradient"] # check if the record contains only the elements we are interested in @@ -369,15 +381,17 @@ def process( raise Exception( "max_records and total_conformers cannot be set at the same time." ) - from modelforge.utils.remote import download_from_zenodo + from modelforge.utils.remote import download_from_url url = self.dataset_download_url # download the dataset - self.name = download_from_zenodo( + download_from_url( url=url, md5_checksum=self.dataset_md5_checksum, output_path=self.local_cache_dir, + output_filename=self.dataset_filename, + length=self.dataset_length, force_download=force_download, ) @@ -394,11 +408,10 @@ def process( self.atomic_numbers_to_limit = None # process the rest of the dataset - if self.name is None: - raise Exception("Failed to retrieve name of file from zenodo.") + self._process_downloaded( self.local_cache_dir, - self.name, + self.dataset_filename, max_records, max_conformers_per_record, total_conformers, diff --git a/modelforge/curation/spice_1_openff_curation.py b/modelforge/curation/spice_1_openff_curation.py index 84a1c193..ab382dac 100644 --- a/modelforge/curation/spice_1_openff_curation.py +++ b/modelforge/curation/spice_1_openff_curation.py @@ -6,6 +6,7 @@ retry = import_("retry").retry from tqdm import tqdm from openff.units import unit +import numpy as np class SPICE1OpenFFCuration(DatasetCuration): @@ -150,7 +151,7 @@ def _init_record_entries_series(self): "name": "single_rec", "dataset_name": "single_rec", "source": "single_rec", - "total_charge": "single_rec", + "total_charge": "series_mol", "atomic_numbers": "single_atom", "n_configs": "single_rec", "reference_energy": "single_rec", @@ -348,7 +349,7 @@ def _calculate_reference_energy_and_charge( return ( sum(atom_energy[s][c] for s, c in zip(symbol, charge)) * unit.hartree, - int(total_charge) * unit.elementary_charge, + np.array([int(total_charge)]) * unit.elementary_charge, ) def _sort_keys(self, non_error_keys: List[str]) -> Tuple[List[str], Dict[str, str]]: @@ -653,6 +654,12 @@ def _process_downloaded( np.array(val["properties"][quantity]).reshape(1, -1, 3), ) ) + for datapoint in self.data: + n_configs = datapoint["n_configs"] + total_charge = datapoint["total_charge"] + datapoint["total_charge"] = ( + np.repeat(total_charge.m, n_configs) * total_charge.u + ) # assign units for datapoint in self.data: for key in datapoint.keys(): @@ -733,6 +740,8 @@ def _process_downloaded( datapoint["mbis_charges"] = datapoint["mbis_charges"][0:n_conformers] datapoint["scf_dipole"] = datapoint["scf_dipole"][0:n_conformers] + datapoint["total_charge"] = datapoint["total_charge"][0:n_conformers] + temp_data.append(datapoint) conformers_count += n_conformers self.data = temp_data diff --git a/modelforge/curation/spice_2_curation.py b/modelforge/curation/spice_2_curation.py index 452b34b1..4783c10b 100644 --- a/modelforge/curation/spice_2_curation.py +++ b/modelforge/curation/spice_2_curation.py @@ -2,6 +2,7 @@ from typing import Optional from loguru import logger from openff.units import unit +import numpy as np class SPICE2Curation(DatasetCuration): @@ -65,6 +66,9 @@ def _init_dataset_parameters(self): self.dataset_md5_checksum = data_inputs[self.version_select][ "dataset_md5_checksum" ] + self.dataset_filename = data_inputs[self.version_select]["dataset_filename"] + self.dataset_length = data_inputs[self.version_select]["dataset_length"] + logger.debug( f"Dataset: {self.version_select} version: {data_inputs[self.version_select]['version']}" ) @@ -155,7 +159,7 @@ def _init_record_entries_series(self): "n_configs": "single_rec", "smiles": "single_rec", "subset": "single_rec", - "total_charge": "single_rec", + "total_charge": "series_mol", "geometry": "series_atom", "dft_total_energy": "series_mol", "dft_total_gradient": "series_atom", @@ -191,7 +195,7 @@ def _calculate_reference_charge(self, smiles: str) -> unit.Quantity: rdmol = Chem.MolFromSmiles(smiles, sanitize=False) total_charge = sum(atom.GetFormalCharge() for atom in rdmol.GetAtoms()) - return int(total_charge) * unit.elementary_charge + return np.array([int(total_charge)]) * unit.elementary_charge def _process_downloaded( self, @@ -249,74 +253,82 @@ def _process_downloaded( # Extract the total number of conformations for a given molecule conformers_per_record = hf[name]["conformations"].shape[0] - keys_list = list(hf[name].keys()) + if conformers_per_record != 0: + keys_list = list(hf[name].keys()) - # temp dictionary for ANI-1x and ANI-1ccx data - ds_temp = {} + # temp dictionary for ANI-1x and ANI-1ccx data + ds_temp = {} - ds_temp["name"] = f"{name}" - ds_temp["smiles"] = hf[name]["smiles"][()][0].decode("utf-8") - ds_temp["atomic_numbers"] = hf[name]["atomic_numbers"][()].reshape( - -1, 1 - ) - if max_conformers_per_record is not None: - conformers_per_record = min( - conformers_per_record, - max_conformers_per_record, + ds_temp["name"] = f"{name}" + ds_temp["smiles"] = hf[name]["smiles"][()][0].decode("utf-8") + ds_temp["atomic_numbers"] = hf[name]["atomic_numbers"][()].reshape( + -1, 1 ) - if total_conformers is not None: - conformers_per_record = min( - conformers_per_record, total_conformers - conformers_counter + if max_conformers_per_record is not None: + conformers_per_record = min( + conformers_per_record, + max_conformers_per_record, + ) + if total_conformers is not None: + conformers_per_record = min( + conformers_per_record, total_conformers - conformers_counter + ) + + ds_temp["n_configs"] = conformers_per_record + + # param_in is the name of the entry, param_data contains input (u_in) and output (u_out) units + for param_in, param_data in self.qm_parameters.items(): + # for consistency between datasets, we will all the particle positions "geometry" + param_out = param_in + if param_in == "geometry": + param_in = "conformations" + + if param_in in keys_list: + temp = hf[name][param_in][()] + if param_in in need_to_reshape: + temp = temp.reshape(-1, 1) + + temp = temp[0:conformers_per_record] + param_unit = param_data["u_in"] + if param_unit is not None: + # check that units in the hdf5 file match those we have defined in self.qm_parameters + try: + assert ( + hf[name][param_in].attrs["units"] + == param_data["u_in"] + ) + except: + msg1 = f'unit mismatch: units in hdf5 file: {hf[name][param_in].attrs["units"]},' + msg2 = f'units defined in curation class: {param_data["u_in"]}.' + + raise AssertionError(f"{msg1} {msg2}") + + ds_temp[param_out] = temp * param_unit + else: + ds_temp[param_out] = temp + total_charge = self._calculate_reference_charge(ds_temp["smiles"]) + + total_charge_reshaped = ( + np.repeat(total_charge.m, conformers_per_record).reshape( + conformers_per_record, -1 + ) + * total_charge.u ) - ds_temp["n_configs"] = conformers_per_record - - # param_in is the name of the entry, param_data contains input (u_in) and output (u_out) units - for param_in, param_data in self.qm_parameters.items(): - # for consistency between datasets, we will all the particle positions "geometry" - param_out = param_in - if param_in == "geometry": - param_in = "conformations" - - if param_in in keys_list: - temp = hf[name][param_in][()] - if param_in in need_to_reshape: - temp = temp.reshape(-1, 1) - - temp = temp[0:conformers_per_record] - param_unit = param_data["u_in"] - if param_unit is not None: - # check that units in the hdf5 file match those we have defined in self.qm_parameters - try: - assert ( - hf[name][param_in].attrs["units"] - == param_data["u_in"] - ) - except: - msg1 = f'unit mismatch: units in hdf5 file: {hf[name][param_in].attrs["units"]},' - msg2 = f'units defined in curation class: {param_data["u_in"]}.' - - raise AssertionError(f"{msg1} {msg2}") - - ds_temp[param_out] = temp * param_unit - else: - ds_temp[param_out] = temp - ds_temp["total_charge"] = self._calculate_reference_charge( - ds_temp["smiles"] - ) - ds_temp["dft_total_force"] = -ds_temp["dft_total_gradient"] - - # check if the record contains only the elements we are interested in - # if this has been defined - add_to_record = True - if atomic_numbers_to_limit is not None: - add_to_record = set(ds_temp["atomic_numbers"].flatten()).issubset( - atomic_numbers_to_limit - ) + ds_temp["total_charge"] = total_charge_reshaped + ds_temp["dft_total_force"] = -ds_temp["dft_total_gradient"] - if add_to_record: - self.data.append(ds_temp) - conformers_counter += conformers_per_record + # check if the record contains only the elements we are interested in + # if this has been defined + add_to_record = True + if atomic_numbers_to_limit is not None: + add_to_record = set( + ds_temp["atomic_numbers"].flatten() + ).issubset(atomic_numbers_to_limit) + + if add_to_record: + self.data.append(ds_temp) + conformers_counter += conformers_per_record if self.convert_units: self._convert_units() @@ -367,15 +379,17 @@ def process( raise ValueError( "max_records and total_conformers cannot be set at the same time." ) - from modelforge.utils.remote import download_from_zenodo + from modelforge.utils.remote import download_from_url url = self.dataset_download_url # download the dataset - self.name = download_from_zenodo( + download_from_url( url=url, md5_checksum=self.dataset_md5_checksum, output_path=self.local_cache_dir, + output_filename=self.dataset_filename, + length=self.dataset_length, force_download=force_download, ) @@ -393,11 +407,10 @@ def process( self.atomic_numbers_to_limit = None # process the rest of the dataset - if self.name is None: - raise Exception("Failed to retrieve name of file from zenodo.") + self._process_downloaded( self.local_cache_dir, - self.name, + self.dataset_filename, max_records, max_conformers_per_record, total_conformers, diff --git a/modelforge/curation/yaml_files/ani1x_curation.yaml b/modelforge/curation/yaml_files/ani1x_curation.yaml index 5c933eba..8ca1c3cc 100644 --- a/modelforge/curation/yaml_files/ani1x_curation.yaml +++ b/modelforge/curation/yaml_files/ani1x_curation.yaml @@ -4,3 +4,5 @@ v_0: version: 0 dataset_download_url: https://springernature.figshare.com/ndownloader/files/18112775 dataset_md5_checksum: 98090dd6679106da861f52bed825ffb7 + dataset_length: 5590846027 + dataset_filename: ani1xrelease.h5 diff --git a/modelforge/curation/yaml_files/ani2x_curation.yaml b/modelforge/curation/yaml_files/ani2x_curation.yaml index eb24880c..8a494f73 100644 --- a/modelforge/curation/yaml_files/ani2x_curation.yaml +++ b/modelforge/curation/yaml_files/ani2x_curation.yaml @@ -4,4 +4,6 @@ v_0: version: 0 dataset_download_url: https://zenodo.org/records/10108942/files/ANI-2x-wB97X-631Gd.tar.gz dataset_md5_checksum: cb1d9effb3d07fc1cc6ced7cd0b1e1f2 + dataset_length: 3705675413 + dataset_filename: ANI-2x-wB97X-631Gd.tar.gz diff --git a/modelforge/curation/yaml_files/qm9_curation.yaml b/modelforge/curation/yaml_files/qm9_curation.yaml index ebd342d2..9b05555c 100644 --- a/modelforge/curation/yaml_files/qm9_curation.yaml +++ b/modelforge/curation/yaml_files/qm9_curation.yaml @@ -4,3 +4,5 @@ v_0: version: 0 dataset_download_url: https://springernature.figshare.com/ndownloader/files/3195389 dataset_md5_checksum: ad1ebd51ee7f5b3a6e32e974e5d54012 + dataset_length: 86144227 + dataset_filename: dsgdb9nsd.xyz.tar.bz2 diff --git a/modelforge/curation/yaml_files/spice1_curation.yaml b/modelforge/curation/yaml_files/spice1_curation.yaml index 6285f26a..4be97814 100644 --- a/modelforge/curation/yaml_files/spice1_curation.yaml +++ b/modelforge/curation/yaml_files/spice1_curation.yaml @@ -5,3 +5,5 @@ v_0: notes: SPICE 1.1.4 release dataset_download_url: https://zenodo.org/records/8222043/files/SPICE-1.1.4.hdf5 dataset_md5_checksum: f27d4c81da0e37d6547276bf6b4ae6a1 + dataset_length: 16058156944 + dataset_filename: SPICE-1.1.4.hdf5 diff --git a/modelforge/curation/yaml_files/spice2_curation.yaml b/modelforge/curation/yaml_files/spice2_curation.yaml index 645cbd60..6ed537c0 100644 --- a/modelforge/curation/yaml_files/spice2_curation.yaml +++ b/modelforge/curation/yaml_files/spice2_curation.yaml @@ -5,3 +5,5 @@ v_0: notes: SPICE 2.0.1 release dataset_download_url: https://zenodo.org/records/10975225/files/SPICE-2.0.1.hdf5 dataset_md5_checksum: bfba2224b6540e1390a579569b475510 + dataset_length: 37479271148 + dataset_filename: SPICE-2.0.1.hdf5 diff --git a/modelforge/custom_types.py b/modelforge/custom_types.py new file mode 100644 index 00000000..31e87126 --- /dev/null +++ b/modelforge/custom_types.py @@ -0,0 +1,14 @@ +from typing import Literal + +ModelType = Literal[ + "ANI2x", "PhysNet", "SchNet", "PaiNN", "SAKE", "TensorNet", "AimNet2" +] +DatasetType = Literal[ + "QM9", + "ANI1X", + "ANI2X", + "SPICE1", + "SPICE2", + "SPICE1_OPENFF", + "PhAlkEthOH", +] diff --git a/modelforge/dataset/__init__.py b/modelforge/dataset/__init__.py index aef82aae..570dd7cb 100644 --- a/modelforge/dataset/__init__.py +++ b/modelforge/dataset/__init__.py @@ -1,4 +1,5 @@ -# defines the interaction with public datasets +""" Module that contains classes and function for loading and processing of datasets. """ + from .qm9 import QM9Dataset from .ani1x import ANI1xDataset from .ani2x import ANI2xDataset diff --git a/modelforge/dataset/ani1x.py b/modelforge/dataset/ani1x.py index e05a5dd2..7c7c0053 100644 --- a/modelforge/dataset/ani1x.py +++ b/modelforge/dataset/ani1x.py @@ -1,3 +1,7 @@ +""" +Data class for handling ANI1x dataset. +""" + from typing import List from .dataset import HDF5Dataset @@ -65,14 +69,25 @@ class ANI1xDataset(HDF5Dataset): F="wb97x_dz.forces", ) + # for simplicifty, commenting out those properties that are cannot be used in our current implementation _available_properties = [ "geometry", "atomic_numbers", "wb97x_dz.energy", "wb97x_dz.forces", - "wb97x_dz.cm5_charges", + # "wb97x_dz.cm5_charges", ] # All properties within the datafile, aside from SMILES/inchi. + # Mapping of available properties to the associated PropertyNames + _available_properties_association = { + "geometry": "positions", + "atomic_numbers": "atomic_numbers", + "wb97x_dz.energy": "E", + "wb97x_dz.forces": "F", + } # note cm5_charges are not actually total charge, as they as per-atom charges, but unit is the same + # we need to discuss if we want to allow per_atom_charges at all and if we want to preprocess + # per atom charges into dipole moment + def __init__( self, dataset_name: str = "ANI1x", diff --git a/modelforge/dataset/ani2x.py b/modelforge/dataset/ani2x.py index dff9b178..7b8a588f 100644 --- a/modelforge/dataset/ani2x.py +++ b/modelforge/dataset/ani2x.py @@ -1,3 +1,7 @@ +""" +Data class for handling ANI2x data. +""" + from typing import List from .dataset import HDF5Dataset @@ -44,7 +48,10 @@ class ANI2xDataset(HDF5Dataset): from modelforge.utils import PropertyNames _property_names = PropertyNames( - atomic_numbers="atomic_numbers", positions="geometry", E="energies", F="forces" + atomic_numbers="atomic_numbers", + positions="geometry", + E="energies", + F="forces", ) _available_properties = [ @@ -54,6 +61,14 @@ class ANI2xDataset(HDF5Dataset): "forces", ] # All properties within the datafile, aside from SMILES/inchi. + # Mapping of available properties to the associated PropertyNames + _available_properties_association = { + "geometry": "positions", + "atomic_numbers": "atomic_numbers", + "energies": "E", + "forces": "F", + } + def __init__( self, dataset_name: str = "ANI2x", diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 6cefab07..5770a2d0 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -1,6 +1,9 @@ +""" +This module contains classes and functions for managing datasets. +""" + import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, NamedTuple +from typing import TYPE_CHECKING, Dict, List, Optional import numpy as np import pytorch_lightning as pl @@ -10,186 +13,74 @@ from torch.utils.data import DataLoader from modelforge.dataset.utils import RandomRecordSplittingStrategy, SplittingStrategy +from modelforge.utils.misc import lock_with_attribute from modelforge.utils.prop import PropertyNames +from modelforge.utils.prop import BatchData +from modelforge.utils.prop import NNPInput +from modelforge.utils.prop import Metadata if TYPE_CHECKING: from modelforge.potential.processing import AtomicSelfEnergies +from enum import Enum -@dataclass(frozen=False) -class Metadata: - """ - A NamedTuple to structure the inputs for neural network potentials. +from pydantic import BaseModel, ConfigDict, Field - Parameters - ---------- + +class CaseInsensitiveEnum(str, Enum): + """ + Enum class that allows case-insensitive comparison of its members. """ - E: torch.Tensor - atomic_subsystem_counts: torch.Tensor - atomic_subsystem_indices_referencing_dataset: torch.Tensor - number_of_atoms: int - F: torch.Tensor = torch.tensor([], dtype=torch.float32) + @classmethod + def _missing_(cls, value): + for member in cls: + if member.value.lower() == value.lower(): + return member + return super()._missing_(value) - def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None - ): - """Move all tensors in this instance to the specified device.""" - if device: - self.E = self.E.to(device) - self.F = self.F.to(device) - self.atomic_subsystem_counts = self.atomic_subsystem_counts.to(device) - self.atomic_subsystem_indices_referencing_dataset = ( - self.atomic_subsystem_indices_referencing_dataset.to(device) - ) - if dtype: - self.E = self.E.to(dtype) - self.F = self.F.to(dtype) - return self + +class DataSetName(CaseInsensitiveEnum): + QM9 = "QM9" + ANI1X = "ANI1X" + ANI2X = "ANI2X" + SPICE1 = "SPICE1" + SPICE2 = "SPICE2" + SPICE1_OPENFF = "SPICE1_OPENFF" + PHALKETHOH = "PhAlkEthOH" -@dataclass -class NNPInput: +class DatasetParameters(BaseModel): """ - A dataclass to structure the inputs for neural network potentials. + Class to hold the dataset parameters. Attributes ---------- - atomic_numbers : torch.Tensor - A 1D tensor containing atomic numbers for each atom in the system(s). - Shape: [num_atoms], where `num_atoms` is the total number of atoms across all systems. - positions : torch.Tensor - A 2D tensor of shape [num_atoms, 3], representing the XYZ coordinates of each atom. - atomic_subsystem_indices : torch.Tensor - A 1D tensor mapping each atom to its respective subsystem or molecule. - This allows for calculations involving multiple molecules or subsystems within the same batch. - Shape: [num_atoms]. - total_charge : torch.Tensor - A tensor with the total charge of molecule. - Shape: [num_systems], where `num_systems` is the number of molecules. + dataset_name : DataSetName + The name of the dataset. + version_select : str + The version of the dataset to use. + num_workers : int + The number of workers to use for the DataLoader. + pin_memory : bool + Whether to pin memory for the DataLoader. + regenerate_processed_cache : bool + Whether to regenerate the processed cache. """ - atomic_numbers: torch.Tensor - positions: Union[torch.Tensor, Quantity] - atomic_subsystem_indices: torch.Tensor - total_charge: torch.Tensor - pair_list: Optional[torch.Tensor] = None - - def to( - self, - *, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - """Move all tensors in this instance to the specified device/dtype.""" - - if device: - self.atomic_numbers = self.atomic_numbers.to(device) - self.positions = self.positions.to(device) - self.atomic_subsystem_indices = self.atomic_subsystem_indices.to(device) - self.total_charge = self.total_charge.to(device) - if self.pair_list is not None: - self.pair_list = self.pair_list.to(device) - if dtype: - self.positions = self.positions.to(dtype) - return self - - def __post_init__(self): - # Set dtype and convert units if necessary - self.atomic_numbers = self.atomic_numbers.to(torch.int32) - self.atomic_subsystem_indices = self.atomic_subsystem_indices.to(torch.int32) - self.total_charge = self.total_charge.to(torch.int32) - - # Unit conversion for positions - if isinstance(self.positions, Quantity): - positions = self.positions.to(unit.nanometer).m - self.positions = torch.tensor( - positions, dtype=torch.float32, requires_grad=True - ) - - # Validate inputs - self._validate_inputs() - - def _validate_inputs(self): - if self.atomic_numbers.dim() != 1: - raise ValueError("atomic_numbers must be a 1D tensor") - if self.positions.dim() != 2 or self.positions.size(1) != 3: - raise ValueError("positions must be a 2D tensor with shape [num_atoms, 3]") - if self.atomic_subsystem_indices.dim() != 1: - raise ValueError("atomic_subsystem_indices must be a 1D tensor") - if self.total_charge.dim() != 1: - raise ValueError("total_charge must be a 1D tensor") - - # Optionally, check that the lengths match if required - if len(self.positions) != len(self.atomic_numbers): - raise ValueError( - "The size of atomic_numbers and the first dimension of positions must match" - ) - if len(self.positions) != len(self.atomic_subsystem_indices): - raise ValueError( - "The size of atomic_subsystem_indices and the first dimension of positions must match" - ) - - def as_namedtuple(self) -> NamedTuple: - """Export the dataclass fields and values as a named tuple.""" - - import collections - from dataclasses import dataclass, fields - - NNPInputTuple = collections.namedtuple( - "NNPInputTuple", [field.name for field in fields(self)] - ) - return NNPInputTuple(*[getattr(self, field.name) for field in fields(self)]) - - def as_jax_namedtuple(self) -> NamedTuple: - """Export the dataclass fields and values as a named tuple. - Convert pytorch tensors to jax arrays.""" - - from dataclasses import dataclass, fields - import collections - from modelforge.utils.io import import_ - - convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax - # from pytorch2jax.pytorch2jax import convert_to_jax - - NNPInputTuple = collections.namedtuple( - "NNPInputTuple", [field.name for field in fields(self)] - ) - return NNPInputTuple( - *[convert_to_jax(getattr(self, field.name)) for field in fields(self)] - ) - - -@dataclass -class BatchData: - nnp_input: NNPInput - metadata: Metadata - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - self.nnp_input = self.nnp_input.to(device=device, dtype=dtype) - self.metadata = self.metadata.to(device=device, dtype=dtype) - return self - + model_config = ConfigDict( + use_enum_values=True, arbitrary_types_allowed=True, validate_assignment=True + ) -class TorchDataset(torch.utils.data.Dataset[Dict[str, torch.Tensor]]): - """ - Wraps a numpy dataset to make it compatible with PyTorch DataLoader. + dataset_name: DataSetName + version_select: str + num_workers: int = Field(gt=0) + pin_memory: bool + regenerate_processed_cache: bool = False - Parameters - ---------- - dataset : np.lib.npyio.NpzFile - The underlying numpy dataset. - property_name : PropertyNames - Names of the properties to extract from the dataset. - preloaded : bool, optional - If True, converts properties to PyTorch tensors ahead of time. Default is False. - - """ +# Define the input class +class TorchDataset(torch.utils.data.Dataset[BatchData]): def __init__( self, dataset: np.lib.npyio.NpzFile, @@ -197,68 +88,37 @@ def __init__( preloaded: bool = False, ): """ - Initializes the TorchDataset with a numpy dataset and property names. + Wraps a numpy dataset to make it compatible with PyTorch DataLoader. Parameters ---------- dataset : np.lib.npyio.NpzFile - The numpy dataset to wrap. + The underlying numpy dataset. property_name : PropertyNames - The property names to extract from the dataset for use in PyTorch. + Names of the properties to extract from the dataset. preloaded : bool, optional - If set to True, properties are preloaded as PyTorch tensors. Default is False. + If True, converts properties to PyTorch tensors ahead of time. Default + is False. """ - - self.properties_of_interest = {} - - self.properties_of_interest["atomic_numbers"] = torch.from_numpy( - dataset[property_name.atomic_numbers].flatten() - ).to(torch.int32) - self.properties_of_interest["positions"] = torch.from_numpy( - dataset[property_name.positions] - ).to(torch.float32) - self.properties_of_interest["E"] = torch.from_numpy( - dataset[property_name.E] - ).to(torch.float64) - - if property_name.total_charge is not None: - self.properties_of_interest["total_charge"] = torch.from_numpy( - dataset[property_name.total_charge] - ).to(torch.int32) - else: - # this is a per atom property, so it will match the first dimension of the geometry - self.properties_of_interest["total_charge"] = torch.zeros( - (dataset[property_name.positions].shape[0], 1) - ).to(torch.int32) - - if property_name.F is not None: - self.properties_of_interest["F"] = torch.from_numpy( - dataset[property_name.F] - ) - else: - # a per atom property in each direction, so it will match geometry - self.properties_of_interest["F"] = torch.zeros( - dataset[property_name.positions].shape - ) + super().__init__() + self.preloaded = preloaded + self.properties_of_interest = self._load_properties(dataset, property_name) self.number_of_records = len(dataset["atomic_subsystem_counts"]) - self.properties_of_interest["pair_list"] = None self.number_of_atoms = len(dataset["atomic_numbers"]) + self.length = len(self.properties_of_interest["E"]) + + # Prepare indices for atom and conformer data + self._prepare_indices(dataset) + def _prepare_indices(self, dataset: np.lib.npyio.NpzFile): + """Prepare indices for atom and conformer data.""" single_atom_start_idxs_by_rec = np.concatenate( [np.array([0]), np.cumsum(dataset["atomic_subsystem_counts"])] ) - # length: n_records + 1 - self.series_mol_start_idxs_by_rec = np.concatenate( [np.array([0]), np.cumsum(dataset["n_confs"])] ) - # length: n_records + 1 - - if len(single_atom_start_idxs_by_rec) != len(self.series_mol_start_idxs_by_rec): - raise ValueError( - "Number of records in `atomic_subsystem_counts` and `n_confs` do not match." - ) self.single_atom_start_idxs_by_conf = np.repeat( single_atom_start_idxs_by_rec[: self.number_of_records], dataset["n_confs"] @@ -267,7 +127,6 @@ def __init__( single_atom_start_idxs_by_rec[1 : self.number_of_records + 1], dataset["n_confs"], ) - # length: n_conformers self.series_atom_start_idxs_by_conf = np.concatenate( [ @@ -277,10 +136,47 @@ def __init__( ), ] ) - # length: n_conformers + 1 - self.length = len(self.properties_of_interest["E"]) - self.preloaded = preloaded + def _load_properties( + self, dataset: np.lib.npyio.NpzFile, property_name: PropertyNames + ) -> Dict[str, torch.Tensor]: + """Load properties from the dataset.""" + properties = { + "atomic_numbers": torch.from_numpy( + dataset[property_name.atomic_numbers].flatten() + ).to(torch.int32), + "positions": torch.from_numpy(dataset[property_name.positions]).to( + torch.float32 + ), + "E": torch.from_numpy(dataset[property_name.E]).to(torch.float64), + } + + properties["total_charge"] = ( + torch.from_numpy(dataset[property_name.total_charge]) + .to(torch.int32) + .unsqueeze(-1) + if property_name.total_charge is not None + and False # FIXME: as soon as I figured out how to make this to a per data point property + else torch.zeros((dataset[property_name.E].shape[0], 1), dtype=torch.int32) + ) + + properties["F"] = ( + torch.from_numpy(dataset[property_name.F]) + if property_name.F is not None + else torch.zeros_like(properties["positions"]) + ) + + properties["dipole_moment"] = ( + torch.from_numpy(dataset[property_name.dipole_moment]) + if property_name.dipole_moment is not None + else torch.zeros( + (dataset[property_name.E].shape[0], 3), dtype=torch.float32 + ) + ) + + properties["pair_list"] = None # Placeholder for pair list + + return properties def __len__(self) -> int: """ @@ -325,26 +221,34 @@ def __setitem__(self, idx: int, value: Dict[str, torch.Tensor]) -> None: for key, val in value.items(): self.properties_of_interest[key][idx] = val - def __getitem__(self, idx: int) -> BatchData: - """ - Fetch a dictionary of the values for the properties of interest for a given conformer index. - - Parameters - ---------- - idx : int - Index of the molecule to fetch data for. + def _set_pairlist(self, idx: int): + # pairlist is set here (instead of l279) because it is not a default property + if self.properties_of_interest["pair_list"] is None: + pair_list = None + else: + pair_list_indices_start = self.properties_of_interest["number_of_pairs"][ + idx + ] + pair_list_indices_end = self.properties_of_interest["number_of_pairs"][ + idx + 1 + ] + pair_list = self.properties_of_interest["pair_list"][ + :, pair_list_indices_start:pair_list_indices_end + ] + return pair_list - Returns - ------- - BatchData instance representing the data for one conformer. - """ + def __getitem__(self, idx: int) -> BatchData: + """Fetch data for a given conformer index.""" series_atom_start_idx = self.series_atom_start_idxs_by_conf[idx] series_atom_end_idx = self.series_atom_start_idxs_by_conf[idx + 1] single_atom_start_idx = self.single_atom_start_idxs_by_conf[idx] single_atom_end_idx = self.single_atom_end_idxs_by_conf[idx] + atomic_numbers = self.properties_of_interest["atomic_numbers"][ single_atom_start_idx:single_atom_end_idx ] + + # get properties (Note that default properties are set in l279) positions = self.properties_of_interest["positions"][ series_atom_start_idx:series_atom_end_idx ] @@ -352,35 +256,24 @@ def __getitem__(self, idx: int) -> BatchData: F = self.properties_of_interest["F"][series_atom_start_idx:series_atom_end_idx] total_charge = self.properties_of_interest["total_charge"][idx] number_of_atoms = len(atomic_numbers) - if self.properties_of_interest["pair_list"] is None: - pair_list = None - else: - pair_list_indices_start = self.properties_of_interest["number_of_pairs"][ - idx - ] - pair_list_indices_end = self.properties_of_interest["number_of_pairs"][ - idx + 1 - ] - pair_list = self.properties_of_interest["pair_list"][ - :, pair_list_indices_start:pair_list_indices_end - ] + dipole_moment = self.properties_of_interest["dipole_moment"][idx] nnp_input = NNPInput( atomic_numbers=atomic_numbers, positions=positions, - pair_list=pair_list, - total_charge=total_charge, + pair_list=self._set_pairlist(idx), + per_system_total_charge=total_charge, atomic_subsystem_indices=torch.zeros(number_of_atoms, dtype=torch.int32), ) - metadata = Metadata( - E=E, - F=F, + per_system_energy=E, + per_atom_force=F, atomic_subsystem_counts=torch.tensor([number_of_atoms], dtype=torch.int32), atomic_subsystem_indices_referencing_dataset=torch.repeat_interleave( torch.tensor([idx], dtype=torch.int32), number_of_atoms ), number_of_atoms=number_of_atoms, + per_system_dipole_moment=dipole_moment, ) return BatchData(nnp_input, metadata) @@ -425,6 +318,8 @@ def __init__( Directory to store the files. force_download : bool, optional If set to True, the data will be downloaded even if it already exists. Default is False. + regenerate_cache : bool, optional + If set to True, the cache file will be regenerated even if it already exists. Default is False. """ self.url = url self.gz_data_file = gz_data_file @@ -529,7 +424,10 @@ def _metadata_validation(self, file_name: str, file_path: str) -> bool: self._npz_metadata["data_keys"], self.properties_of_interest ): log.warning( - f"Data keys used to generate {file_path}/{file_name} ({self._npz_metadata['data_keys']}) do not match data loader ({self.properties_of_interest}) ." + f"Data keys used to generate {file_path}/{file_name} ({self._npz_metadata['data_keys']})" + ) + log.warning( + f"do not match data loader ({self.properties_of_interest})." ) return False @@ -593,6 +491,7 @@ def _from_hdf5(self) -> None: """ from collections import OrderedDict + from modelforge.utils.prop import PropertyUnits import h5py import tqdm @@ -632,6 +531,8 @@ def _from_hdf5(self) -> None: # value shapes: (*) single_atom_data: Dict[str, List[np.ndarray]] = OrderedDict() # value shapes: (n_atoms, *) + single_mol_data: Dict[str, List[np.ndarray]] = OrderedDict() + # value_shapes: (*) series_mol_data: Dict[str, List[np.ndarray]] = OrderedDict() # value shapes: (n_confs, *) series_atom_data: Dict[str, List[np.ndarray]] = OrderedDict() @@ -639,6 +540,7 @@ def _from_hdf5(self) -> None: # initialize each relevant value in data dicts to empty list for value in self.properties_of_interest: + value_format = hf[next(iter(hf.keys()))][value].attrs["format"] if value_format == "single_rec": single_rec_data[value] = [] @@ -652,7 +554,15 @@ def _from_hdf5(self) -> None: raise ValueError( f"Unknown format type {value_format} for property {value}" ) - + log.debug(f"Properties of Interest: {self.properties_of_interest}") + target_units = {} + for key in self._available_properties_association.keys(): + for property in self.properties_of_interest: + if key == property: + prop_name = self._available_properties_association[key] + target_units[property] = PropertyUnits[prop_name] + + log.debug(f"Properties of Interest units: {target_units}") self.atomic_subsystem_counts = [] # number of atoms in each record self.n_confs = [] # number of conformers in each record @@ -670,12 +580,11 @@ def _from_hdf5(self) -> None: value in hf[record].keys() for value in self.properties_of_interest ] - if all(property_found): # we want to exclude conformers with NaN values for any property of interest - configs_nan_by_prop: Dict[ - str, np.ndarray - ] = OrderedDict() # ndarray.size (n_configs, ) + configs_nan_by_prop: Dict[str, np.ndarray] = ( + OrderedDict() + ) # ndarray.size (n_configs, ) for value in list(series_mol_data.keys()) + list( series_atom_data.keys() ): @@ -696,13 +605,17 @@ def _from_hdf5(self) -> None: ) != 1 ): + val_temp = [ + value.shape + for value in configs_nan_by_prop.values() + ] raise ValueError( - f"Number of conformers is inconsistent across properties for record {record}" + f"Number of conformers is inconsistent across properties for record {record}: values {val_temp}" ) configs_nan = np.logical_or.reduce( list(configs_nan_by_prop.values()) - ) # boolean array of size (n_configs, ) + ) # boolean array of size (n_configsself.properties_of_interest, ) n_confs_rec = sum(~configs_nan) atomic_subsystem_counts_rec = hf[record][ @@ -717,6 +630,7 @@ def _from_hdf5(self) -> None: for value in single_atom_data.keys(): record_array = hf[record][value][()] + if record_array.shape[0] != atomic_subsystem_counts_rec: raise ValueError( f"Number of atoms for property {value} is inconsistent with other properties for record {record}" @@ -726,6 +640,14 @@ def _from_hdf5(self) -> None: for value in series_atom_data.keys(): record_array = hf[record][value][()][~configs_nan] + if "u" in hf[record][value].attrs: + units = hf[record][value].attrs["u"] + if units != "dimensionless": + record_array = Quantity(record_array, units).to( + target_units[value] + ) + record_array = record_array.magnitude + try: if ( record_array.shape[1] @@ -751,18 +673,32 @@ def _from_hdf5(self) -> None: ) for value in series_mol_data.keys(): + record_array = hf[record][value][()][~configs_nan] + if "u" in hf[record][value].attrs: + units = hf[record][value].attrs["u"] + if units != "dimensionless": + record_array = Quantity(record_array, units).to( + target_units[value] + ) + record_array = record_array.magnitude series_mol_data[value].append(record_array) for value in single_rec_data.keys(): record_array = hf[record][value][()] single_rec_data[value].append(record_array) + else: + log.warning( + f"Skipping record {record} as not all properties of interest are present." + ) # convert lists of arrays to single arrays data = OrderedDict() for value in single_atom_data.keys(): data[value] = np.concatenate(single_atom_data[value], axis=0) + for value in single_mol_data.keys(): + data[value] = np.concatenate(single_mol_data[value], axis=0) for value in series_mol_data.keys(): data[value] = np.concatenate(series_mol_data[value], axis=0) for value in series_atom_data.keys(): @@ -885,19 +821,8 @@ class DatasetFactory: Factory for creating TorchDataset instances from HDF5 data. Methods are provided to load or process data as needed, handling caching to improve efficiency. - - Examples - -------- - >>> factory = DatasetFactory() - >>> qm9_data = QM9Data() - >>> torch_dataset = factory.create_dataset(qm9_data) """ - def __init__( - self, - ) -> None: - pass - @staticmethod def _load_or_process_data( data: HDF5Dataset, @@ -971,30 +896,23 @@ def create_dataset( The resulting PyTorch-compatible dataset. """ - log.info(f"Creating {data.dataset_name} dataset") + log.info(f"Creating dataset from {data.url}") DatasetFactory._load_or_process_data(data) return TorchDataset(data.numpy_data, data._property_names) -from torch import nn from openff.units import unit +from modelforge.custom_types import DatasetType class DataModule(pl.LightningDataModule): def __init__( self, - name: Literal[ - "QM9", - "ANI1X", - "ANI2X", - "SPICE114", - "SPICE2", - "SPICE114_OPENFF", - "PhAlkEthOH", - ], + name: DatasetType, splitting_strategy: SplittingStrategy = RandomRecordSplittingStrategy(), batch_size: int = 64, remove_self_energies: bool = True, + shift_center_of_mass_to_origin: bool = False, atomic_self_energies: Optional[Dict[str, float]] = None, regression_ase: bool = False, force_download: bool = False, @@ -1002,6 +920,7 @@ def __init__( local_cache_dir: str = "./", regenerate_cache: bool = False, regenerate_dataset_statistic: bool = False, + regenerate_processed_cache: bool = True, ): """ Initializes adData module for PyTorch Lightning handling data preparation and loading object with the specified configuration. @@ -1012,7 +931,7 @@ def __init__( Parameters --------- - name: Literal["QM9", "ANI1X", "ANI2X", "SPICE114", "SPICE2", "SPICE114_OPENFF"] + name: Literal["QM9", "ANI1X", "ANI2X", "SPICE1", "SPICE2", "SPICE1_OPENFF"] The name of the dataset to use. splitting_strategy : SplittingStrategy, defaults to RandomRecordSplittingStrategy The strategy to use for splitting the dataset into train, test, and validation sets. . @@ -1020,6 +939,9 @@ def __init__( The batch size to use for the dataset. remove_self_energies : bool, defaults to True Whether to remove the self energies from the dataset. + shift_center_of_mass_to_origin: bool, defaults to False + Whether to shift the center of mass of the molecule to the origin. This is necessary if using the + dipole moment in the loss function. atomic_self_energies : Optional[Dict[str, float]] A dictionary mapping element names to their self energies. If not provided, the self energies will be calculated. regression_ase: bool, defaults to False @@ -1035,12 +957,16 @@ def __init__( regenerate_cache : bool, defaults to False Whether to regenerate the cache. """ + from modelforge.potential.neighbors import Pairlist + import os + super().__init__() self.name = name self.batch_size = batch_size self.splitting_strategy = splitting_strategy self.remove_self_energies = remove_self_energies + self.shift_center_of_mass_to_origin = shift_center_of_mass_to_origin self.dict_atomic_self_energies = ( atomic_self_energies # element name (e.g., 'H') maps to self energies ) @@ -1048,31 +974,63 @@ def __init__( self.force_download = force_download self.version_select = version_select self.regenerate_dataset_statistic = regenerate_dataset_statistic - self.train_dataset = None - self.test_dataset = None - self.val_dataset = None - import os + self.train_dataset: Optional[TorchDataset] = None + self.val_dataset: Optional[TorchDataset] = None + self.test_dataset: Optional[TorchDataset] = None # make sure we can handle a path with a ~ in it self.local_cache_dir = os.path.expanduser(local_cache_dir) + # create the local cache directory if it does not exist + os.makedirs(self.local_cache_dir, exist_ok=True) self.regenerate_cache = regenerate_cache - from modelforge.potential.models import Pairlist + # Use a logical OR to ensure regenerate_processed_cache is True when + # regenerate_cache is True + self.regenerate_processed_cache = ( + regenerate_processed_cache or self.regenerate_cache + ) self.pairlist = Pairlist() self.dataset_statistic_filename = ( f"{self.local_cache_dir}/{self.name}_dataset_statistic.toml" ) + self.cache_processed_dataset_filename = ( + f"{self.local_cache_dir}/{self.name}_{self.version_select}_processed.pt" + ) + self.lock_file = f"{self.cache_processed_dataset_filename}.lockfile" + + def transfer_batch_to_device(self, batch, device, dataloader_idx): + # move all tensors to the device + return batch.to_device(device) + @lock_with_attribute("lock_file") def prepare_data( self, ) -> None: """ - Prepares the dataset for use. This method is responsible for the initial processing of the data such as calculating self energies, atomic energy statistics, and splitting. It is executed only once per node. + Prepares the dataset for use. This method is responsible for the initial + processing of the data such as calculating self energies, atomic energy + statistics, and splitting. It is executed only once per node. """ + # check if there is a filelock present, if so, wait until it is removed + + # if the dataset has already been processed, skip this step + if ( + os.path.exists(self.cache_processed_dataset_filename) + and not self.regenerate_processed_cache + ): + if not os.path.exists(self.dataset_statistic_filename): + raise FileNotFoundError( + f"Dataset statistics file {self.dataset_statistic_filename} not found. Please regenerate the cache." + ) + log.info( + f'Processed dataset already exists: {self.cache_processed_dataset_filename}. Skipping "prepare_data" step.' + ) + return None + + # if the dataset is not already processed, process it from modelforge.dataset import _ImplementedDatasets - import toml - dataset_class = _ImplementedDatasets.get_dataset_class(self.name) + dataset_class = _ImplementedDatasets.get_dataset_class(str(self.name)) dataset = dataset_class( force_download=self.force_download, version_select=self.version_select, @@ -1080,7 +1038,6 @@ def prepare_data( regenerate_cache=self.regenerate_cache, ) torch_dataset = self._create_torch_dataset(dataset) - # if dataset statistics is present load it from disk if ( os.path.exists(self.dataset_statistic_filename) @@ -1218,16 +1175,16 @@ def _calculate_atomic_self_energies( def _cache_dataset(self, torch_dataset): """Cache the dataset and its statistics using PyTorch's serialization.""" - torch.save(torch_dataset, "torch_dataset.pt") - # sleep for 1 second to make sure that the dataset was written to disk + torch.save(torch_dataset, self.cache_processed_dataset_filename) + # sleep for 5 second to make sure that the dataset was written to disk import time - time.sleep(1) + time.sleep(5) def setup(self, stage: Optional[str] = None) -> None: """Sets up datasets for the train, validation, and test stages based on the stage argument.""" - self.torch_dataset = torch.load("torch_dataset.pt") + self.torch_dataset = torch.load(self.cache_processed_dataset_filename) ( self.train_dataset, self.val_dataset, @@ -1277,7 +1234,7 @@ def _per_datapoint_operations( from tqdm import tqdm # remove the self energies if requested - log.info("Precalculating pairlist for dataset") + log.info("Performing per datapoint operations in the dataset dataset") if self.remove_self_energies: log.info("Removing self energies from the dataset") @@ -1294,10 +1251,41 @@ def _per_datapoint_operations( dataset[i] = {"E": dataset.properties_of_interest["E"][i] - energy} + if self.shift_center_of_mass_to_origin: + log.info("Shifting the center of mass of each molecule to the origin.") + from openff.units.elements import MASSES + + for i in tqdm(range(len(dataset)), desc="Process dataset"): + start_idx = dataset.single_atom_start_idxs_by_conf[i] + end_idx = dataset.single_atom_end_idxs_by_conf[i] + + atomic_masses = torch.Tensor( + [ + MASSES[atomic_number].m + for atomic_number in dataset.properties_of_interest[ + "atomic_numbers" + ][start_idx:end_idx].tolist() + ] + ) + molecule_mass = torch.sum(atomic_masses) + + start_idx_mol = dataset.series_atom_start_idxs_by_conf[i] + end_idx_mol = dataset.series_atom_start_idxs_by_conf[i + 1] + + positions = dataset.properties_of_interest["positions"][ + start_idx_mol:end_idx_mol + ] + center_of_mass = ( + torch.einsum("i, ij->j", atomic_masses, positions) / molecule_mass + ) + dataset.properties_of_interest["positions"][ + start_idx_mol:end_idx_mol + ] -= center_of_mass + from torch.utils.data import DataLoader all_pairs = [] - n_pairs_per_molecule_list = [torch.tensor([0], dtype=torch.int16)] + n_pairs_per_system_list = [torch.tensor([0], dtype=torch.int16)] for batch in tqdm( DataLoader( @@ -1318,9 +1306,9 @@ def _per_datapoint_operations( batch.nnp_input.atomic_subsystem_indices.to("cpu") ) all_pairs.append(torch.from_numpy(pairs_batch)) - n_pairs_per_molecule_list.append(torch.from_numpy(n_pairs_batch)) + n_pairs_per_system_list.append(torch.from_numpy(n_pairs_batch)) - nr_of_pairs = torch.cat(n_pairs_per_molecule_list, dim=0) + nr_of_pairs = torch.cat(n_pairs_per_system_list, dim=0) nr_of_pairs_in_dataset = torch.cumsum(nr_of_pairs, dim=0, dtype=torch.int64) # Determine N (number of tensors) and K (maximum M) @@ -1363,7 +1351,7 @@ def val_dataloader(self, num_workers: int = 4) -> DataLoader: num_workers=num_workers, ) - def test_dataloader(self) -> DataLoader: + def test_dataloader(self, num_workers: int = 4) -> DataLoader: """ Create a DataLoader for the test dataset. @@ -1373,21 +1361,34 @@ def test_dataloader(self) -> DataLoader: DataLoader containing the test dataset. """ return DataLoader( - self.test_dataset, batch_size=self.batch_size, collate_fn=collate_conformers + self.test_dataset, + batch_size=self.batch_size, + collate_fn=collate_conformers, + num_workers=num_workers, ) -from typing import Tuple +def collate_conformers(conf_list: List[BatchData]) -> BatchData: + """ + Collate a list of BatchData instances into a single BatchData instance. + Parameters + ---------- + conf_list : List[BatchData] + List of BatchData instances. -def collate_conformers(conf_list: List[BatchData]) -> BatchData: - """Collate a list of BatchData instances with one conformer each into a single BatchData instance.""" + Returns + ------- + BatchData + Collated batch data. + """ atomic_numbers_list = [] positions_list = [] total_charge_list = [] E_list = [] # total energy F_list = [] # forces ij_list = [] + dipole_moment_list = [] atomic_subsystem_counts_list = [] atomic_subsystem_indices_referencing_dataset_list = [] @@ -1399,7 +1400,7 @@ def collate_conformers(conf_list: List[BatchData]) -> BatchData: else False ) - for idx, conf in enumerate(conf_list): + for conf in conf_list: if pair_list_present: ## pairlist # generate pairlist without padded values @@ -1410,9 +1411,10 @@ def collate_conformers(conf_list: List[BatchData]) -> BatchData: atomic_numbers_list.append(conf.nnp_input.atomic_numbers) positions_list.append(conf.nnp_input.positions) - total_charge_list.append(conf.nnp_input.total_charge) - E_list.append(conf.metadata.E) - F_list.append(conf.metadata.F) + total_charge_list.append(conf.nnp_input.per_system_total_charge) + dipole_moment_list.append(conf.metadata.per_system_dipole_moment) + E_list.append(conf.metadata.per_system_energy) + F_list.append(conf.metadata.per_atom_force) atomic_subsystem_counts_list.append(conf.metadata.atomic_subsystem_counts) atomic_subsystem_indices_referencing_dataset_list.append( conf.metadata.atomic_subsystem_indices_referencing_dataset @@ -1426,9 +1428,10 @@ def collate_conformers(conf_list: List[BatchData]) -> BatchData: atomic_subsystem_indices_referencing_dataset_list ) atomic_numbers = torch.cat(atomic_numbers_list) - total_charge = torch.cat(total_charge_list) + total_charge = torch.stack(total_charge_list) positions = torch.cat(positions_list).requires_grad_(True) F = torch.cat(F_list).to(torch.float64) + dipole_moment = torch.stack(dipole_moment_list).to(torch.float64) E = torch.stack(E_list) if pair_list_present: IJ_cat = torch.cat(ij_list, dim=1).to(torch.int64) @@ -1438,15 +1441,89 @@ def collate_conformers(conf_list: List[BatchData]) -> BatchData: nnp_input = NNPInput( atomic_numbers=atomic_numbers, positions=positions, - total_charge=total_charge, + per_system_total_charge=total_charge, atomic_subsystem_indices=atomic_subsystem_indices, pair_list=IJ_cat, ) metadata = Metadata( - E=E, - F=F, + per_system_energy=E, + per_atom_force=F, atomic_subsystem_counts=atomic_subsystem_counts, atomic_subsystem_indices_referencing_dataset=atomic_subsystem_indices_referencing_dataset, number_of_atoms=atomic_numbers.numel(), + per_system_dipole_moment=dipole_moment, ) return BatchData(nnp_input, metadata) + + +from modelforge.dataset.dataset import DatasetFactory +from modelforge.dataset.utils import ( + FirstComeFirstServeSplittingStrategy, + SplittingStrategy, +) + + +def initialize_datamodule( + dataset_name: str, + version_select: str = "nc_1000_v0", + batch_size: int = 64, + splitting_strategy: SplittingStrategy = FirstComeFirstServeSplittingStrategy(), + remove_self_energies: bool = True, + shift_center_of_mass_to_origin: bool = False, + regression_ase: bool = False, + regenerate_dataset_statistic: bool = False, + local_cache_dir="./", +) -> DataModule: + """ + Initialize a dataset for a given mode. + """ + + data_module = DataModule( + dataset_name, + splitting_strategy=splitting_strategy, + batch_size=batch_size, + version_select=version_select, + remove_self_energies=remove_self_energies, + shift_center_of_mass_to_origin=shift_center_of_mass_to_origin, + regression_ase=regression_ase, + regenerate_dataset_statistic=regenerate_dataset_statistic, + local_cache_dir=local_cache_dir, + ) + data_module.prepare_data() + data_module.setup() + return data_module + + +def single_batch(batch_size: int = 64, dataset_name="QM9", local_cache_dir="./"): + """ + Utility function to create a single batch of data for testing. + """ + data_module = initialize_datamodule( + dataset_name=dataset_name, + batch_size=batch_size, + version_select="nc_1000_v0", + local_cache_dir=local_cache_dir, + ) + return next(iter(data_module.train_dataloader(shuffle=False))) + + +def initialize_dataset( + dataset_name: str, + local_cache_dir: str, + versions_select: str = "nc_1000_v0", + force_download: bool = False, +) -> DataModule: + """ + Initialize a dataset for a given mode. + """ + from modelforge.dataset import _ImplementedDatasets + + factory = DatasetFactory() + data = _ImplementedDatasets.get_dataset_class(dataset_name)( + local_cache_dir=local_cache_dir, + version_select=versions_select, + force_download=force_download, + ) + dataset = factory.create_dataset(data) + + return dataset diff --git a/modelforge/dataset/phalkethoh.py b/modelforge/dataset/phalkethoh.py index dbf435d1..f2b011bf 100644 --- a/modelforge/dataset/phalkethoh.py +++ b/modelforge/dataset/phalkethoh.py @@ -1,3 +1,7 @@ +""" +Data class for handling OpenFF Sandbox CHO PhAlkEthOH v1.0 dataset. +""" + from typing import List from .dataset import HDF5Dataset @@ -50,6 +54,16 @@ class PhAlkEthOHDataset(HDF5Dataset): "total_charge", ] # All properties within the datafile, aside from SMILES/inchi. + # Mapping of available properties to the associated PropertyNames + _available_properties_association = { + "geometry": "positions", + "atomic_numbers": "atomic_numbers", + "dft_total_energy": "E", + "dft_total_force": "F", + "scf_dipole": "dipole_moment", + "total_charge": "total_charge", + } + def __init__( self, dataset_name: str = "PhAlkEthOH", @@ -88,6 +102,7 @@ def __init__( "dft_total_energy", "dft_total_force", "total_charge", + "scf_dipole", ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest diff --git a/modelforge/dataset/qm9.py b/modelforge/dataset/qm9.py index 34cf840a..85312d4a 100644 --- a/modelforge/dataset/qm9.py +++ b/modelforge/dataset/qm9.py @@ -1,3 +1,7 @@ +""" +Data class for handling QM9 data. +""" + from typing import List from .dataset import HDF5Dataset @@ -37,9 +41,10 @@ class QM9Dataset(HDF5Dataset): _property_names = PropertyNames( atomic_numbers="atomic_numbers", positions="geometry", - E="internal_energy_at_0K", # Q="charges" + E="internal_energy_at_0K", ) + # for simplicity, commenting out those properties that are cannot be used in our current implementation _available_properties = [ "geometry", "atomic_numbers", @@ -47,20 +52,34 @@ class QM9Dataset(HDF5Dataset): "internal_energy_at_298.15K", "enthalpy_at_298.15K", "free_energy_at_298.15K", - "heat_capacity_at_298.15K", + # "heat_capacity_at_298.15K", "zero_point_vibrational_energy", - "electronic_spatial_extent", + # "electronic_spatial_extent", "lumo-homo_gap", "energy_of_homo", "energy_of_lumo", - "rotational_constant_A", - "rotational_constant_B", - "rotational_constant_C", + # "rotational_constant_A", + # "rotational_constant_B", + # "rotational_constant_C", "dipole_moment", - "isotropic_polarizability", - "charges", + # "isotropic_polarizability", + # "charges", ] # All properties within the datafile, aside from SMILES/inchi. + _available_properties_association = { + "geometry": "positions", + "atomic_numbers": "atomic_numbers", + "internal_energy_at_0K": "E", + "internal_energy_at_298.15K": "E", + "enthalpy_at_298.15K": "E", + "free_energy_at_298.15K": "E", + "zero_point_vibrational_energy": "E", + "lumo-homo_gap": "E", + "energy_of_homo": "E", + "energy_of_lumo": "E", + "dipole_moment": "dipole_moment", + } + def __init__( self, dataset_name: str = "QM9", @@ -97,7 +116,7 @@ def __init__( "geometry", "atomic_numbers", "internal_energy_at_0K", - "charges", + "dipole_moment", ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest @@ -241,7 +260,3 @@ def _download(self) -> None: length=self.gz_data_file["length"], force_download=self.force_download, ) - # from modelforge.dataset.utils import _download_from_url - # - # url = self.test_url if self.for_unit_testing else self.full_url - # _download_from_url(url, self.raw_data_file) diff --git a/modelforge/dataset/spice1.py b/modelforge/dataset/spice1.py index 38ea4147..15d48aa0 100644 --- a/modelforge/dataset/spice1.py +++ b/modelforge/dataset/spice1.py @@ -1,3 +1,7 @@ +""" +SPICE1Dataset class for handling the SPICE 1 dataset. +""" + from typing import List from .dataset import HDF5Dataset @@ -50,24 +54,38 @@ class SPICE1Dataset(HDF5Dataset): positions="geometry", E="dft_total_energy", F="dft_total_force", - total_charge="mbis_charges", + total_charge="total_charge", + dipole_moment="scf_dipole", ) + # note for simplicifty, commenting out those properties that cannot be used in the current implementation _available_properties = [ "geometry", "atomic_numbers", "dft_total_energy", "dft_total_force", - "mbis_charges", - "mbis_multipoles", - "mbis_octopoles", + # "mbis_charges", + # "mbis_multipoles", + # "mbis_octopoles", "formation_energy", "scf_dipole", - "scf_quadrupole", + # "scf_quadrupole", "total_charge", "reference_energy", ] # All properties within the datafile, aside from SMILES/inchi. + # Mapping of available properties to the associated PropertyNames + _available_properties_association = { + "geometry": "positions", + "atomic_numbers": "atomic_numbers", + "dft_total_energy": "E", + "dft_total_force": "F", + "formation_energy": "E", + "scf_dipole": "dipole_moment", + "total_charge": "total_charge", + "reference_energy": "E", + } + def __init__( self, dataset_name: str = "SPICE1", @@ -105,7 +123,8 @@ def __init__( "atomic_numbers", "dft_total_energy", "dft_total_force", - "mbis_charges", + "total_charge", + "scf_dipole", ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest diff --git a/modelforge/dataset/spice1openff.py b/modelforge/dataset/spice1openff.py index 1902b7dc..d42fca6d 100644 --- a/modelforge/dataset/spice1openff.py +++ b/modelforge/dataset/spice1openff.py @@ -1,3 +1,7 @@ +""" +Data class for handling SPICE 1 dataset at the OpenForceField level of theory. +""" + from typing import List from .dataset import HDF5Dataset @@ -64,21 +68,35 @@ class SPICE1OpenFFDataset(HDF5Dataset): positions="geometry", E="dft_total_energy", F="dft_total_force", - total_charge="mbis_charges", + total_charge="total_charge", ) + # commenting out those properties that are cannot be used in our current implementation _available_properties = [ "geometry", "atomic_numbers", "dft_total_energy", "dft_total_force", - "mbis_charges", + # "mbis_charges", "formation_energy", "scf_dipole", "total_charge", "reference_energy", ] # All properties within the datafile, aside from SMILES/inchi. + # note these are simply mapping to a property with equivalent units in the dataset + # not implying we would want to use this for training + _available_properties_association = { + "geometry": "positions", + "atomic_numbers": "atomic_numbers", + "dft_total_energy": "E", + "dft_total_force": "F", + "formation_energy": "E", + "scf_dipole": "dipole_moment", + "total_charge": "total_charge", + "reference_energy": "E", + } + def __init__( self, dataset_name: str = "SPICE1_openff", @@ -116,7 +134,7 @@ def __init__( "atomic_numbers", "dft_total_energy", "dft_total_force", - "mbis_charges", + "total_charge", ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest diff --git a/modelforge/dataset/spice2.py b/modelforge/dataset/spice2.py index 77ed69bb..9b104729 100644 --- a/modelforge/dataset/spice2.py +++ b/modelforge/dataset/spice2.py @@ -1,3 +1,7 @@ +""" +Data class for handling SPICE 2 dataset. +""" + from typing import List from .dataset import HDF5Dataset @@ -78,21 +82,34 @@ class SPICE2Dataset(HDF5Dataset): positions="geometry", E="dft_total_energy", F="dft_total_force", - total_charge="mbis_charges", + total_charge="total_charge", + dipole_moment="scf_dipole", ) + # for simplicifty, commenting out those properties that are cannot be used in our current implementation _available_properties = [ "geometry", "atomic_numbers", "dft_total_energy", "dft_total_force", - "mbis_charges", + # "mbis_charges", "formation_energy", "scf_dipole", "total_charge", "reference_energy", ] # All properties within the datafile, aside from SMILES/inchi. + _available_properties_association = { + "geometry": "positions", + "atomic_numbers": "atomic_numbers", + "dft_total_energy": "E", + "dft_total_force": "F", + "formation_energy": "E", + "scf_dipole": "dipole_moment", + "total_charge": "total_charge", + "reference_energy": "E", + } + def __init__( self, dataset_name: str = "SPICE2", @@ -130,7 +147,8 @@ def __init__( "atomic_numbers", "dft_total_energy", "dft_total_force", - "mbis_charges", + "total_charge", + "scf_dipole", ] # NOTE: Default values self._properties_of_interest = _default_properties_of_interest diff --git a/modelforge/dataset/utils.py b/modelforge/dataset/utils.py index 49df52dd..6b915fce 100644 --- a/modelforge/dataset/utils.py +++ b/modelforge/dataset/utils.py @@ -1,3 +1,7 @@ +""" +Utility functions for dataset handling. +""" + from __future__ import annotations import warnings @@ -117,7 +121,7 @@ def calculate_mean_and_variance( log.info("Calculating mean and variance of atomic energies") for batch_data in tqdm.tqdm(dataloader): E_scaled = ( - batch_data.metadata.E + batch_data.metadata.per_system_energy / batch_data.metadata.atomic_subsystem_counts.view(-1, 1) ) online_estimator.update(E_scaled) @@ -143,9 +147,9 @@ def _calculate_self_energies(torch_dataset, collate_fn) -> Dict[str, unit.Quanti # Determine the size of the counts tensor num_molecules = torch_dataset.number_of_records # Determine up to which Z we detect elements - max_atomic_number = 100 + maximum_atomic_number = 100 # Initialize the counts tensor - counts = torch.zeros(num_molecules, max_atomic_number + 1, dtype=torch.int16) + counts = torch.zeros(num_molecules, maximum_atomic_number + 1, dtype=torch.int16) # save energies in list energy_array = torch.zeros(torch_dataset.number_of_records, dtype=torch.float64) # for filling in the element count matrix @@ -158,9 +162,8 @@ def _calculate_self_energies(torch_dataset, collate_fn) -> Dict[str, unit.Quanti for batch in DataLoader( torch_dataset, batch_size=batch_size, collate_fn=collate_fn ): - a = 7 energies, atomic_numbers, molecules_id = ( - batch.metadata.E.squeeze(), + batch.metadata.per_system_energy.squeeze(), batch.nnp_input.atomic_numbers.squeeze(-1).to(torch.int64), batch.nnp_input.atomic_subsystem_indices.to(torch.int16), ) @@ -324,6 +327,7 @@ class RandomRecordSplittingStrategy(SplittingStrategy): Strategy to split a dataset randomly, keeping all conformers in a record in the same split. """ + def __init__(self, seed: int = 42, split: List[float] = [0.8, 0.1, 0.1]): """ This strategy splits a dataset randomly based on provided ratios for training, validation, @@ -507,7 +511,6 @@ def split(self, dataset: "TorchDataset") -> Tuple[Subset, Subset, Subset]: return (train_d, val_d, test_d) - REGISTERED_SPLITTING_STRATEGIES = { "first_come_first_serve": FirstComeFirstServeSplittingStrategy, "random_record_splitting_strategy": RandomRecordSplittingStrategy, diff --git a/modelforge/dataset/yaml_files/PhAlkEthOH.yaml b/modelforge/dataset/yaml_files/PhAlkEthOH.yaml index 06ca005e..7d8d186a 100644 --- a/modelforge/dataset/yaml_files/PhAlkEthOH.yaml +++ b/modelforge/dataset/yaml_files/PhAlkEthOH.yaml @@ -1,6 +1,66 @@ dataset: PhAlkEthOH -latest: full_dataset_v0 -latest_test: nc_1000_v0 +latest: full_dataset_v1 +latest_test: nc_1000_v1 +full_dataset_v1: + version: 1 + doi: 10.5281/zenodo.13450735 + notes: removes high force conformers + gz_data_file: + length: 3300668359 + md5: b051af374f3233e2925f7a1b96707772 + name: PhAlkEthOH_dataset_v1.hdf5.gz + hdf5_data_file: + md5: f5d9dccb8e79a51892b671108bc57bde + name: PhAlkEthOH_dataset_v1.hdf5 + processed_data_file: + md5: null + name: PhAlkEthOH_dataset_v1_processed.npz + url: https://zenodo.org/records/13450735/files/PhAlkEthOH_openff_dataset_v1.hdf5.gz +nc_1000_v1: + version: 1 + doi: 10.5281/zenodo.13560343 + notes: removes high force conformers, 1000 conformers, max 10 per molecule + gz_data_file: + length: 2702091 + md5: 76b421802bef68f858757dba41f3ea2e + name: PhAlkEthOH_dataset_v1_nc_1000.hdf5.gz + hdf5_data_file: + md5: 244eb8d1b3547b8da229fd1507fb4d4e + name: PhAlkEthOH_dataset_v1_nc_1000.hdf5 + processed_data_file: + md5: null + name: PhAlkEthOH_dataset_v1_nc_1000_processed.npz + url: https://zenodo.org/records/13560343/files/PhAlkEthOH_openff_dataset_v1_ntc_1000.hdf5.gz +full_dataset_min_v1: + version: 1 + doi: 10.5281/zenodo.13561100 + notes: removes high force configurations, only contains final optimized configuration + gz_data_file: + length: 31352642 + md5: 205b0b7bc1858b1d3745480d9a29a770 + name: PhAlkEthOH_dataset_v1_min.hdf5.gz + hdf5_data_file: + md5: 41cb40718f8872baa6c468ab08574d46 + name: PhAlkEthOH_dataset_v1_min.hdf5 + processed_data_file: + md5: null + name: PhAlkEthOH_dataset_v1_min_processed.npz + url: https://zenodo.org/records/13561100/files/PhAlkEthOH_openff_dataset_v1_min.hdf5.gz +nc_1000_min_v1: + version: 1 + doi: 10.5281/zenodo.13576458 + notes: removes high force conformers, 1000 conformers, only contains final optimized configuration + gz_data_file: + length: 3476870 + md5: 7261f4738efd4bf8409268961837ba78 + name: PhAlkEthOH_dataset_v1_nc_1000_min.hdf5.gz + hdf5_data_file: + md5: 5d347a78c6c3b45531870a05d5aab77e + name: PhAlkEthOH_dataset_v1_nc_1000_min.hdf5 + processed_data_file: + md5: null + name: PhAlkEthOH_dataset_v1_nc_1000_min_processed.npz + url: https://zenodo.org/records/13576458/files/PhAlkEthOH_openff_dataset_v1_ntc_1000_min.hdf5.gz full_dataset_v0: version: 0 doi: 10.5281/zenodo.12174233 diff --git a/modelforge/dataset/yaml_files/__init__.py b/modelforge/dataset/yaml_files/__init__.py index e69de29b..652721b8 100644 --- a/modelforge/dataset/yaml_files/__init__.py +++ b/modelforge/dataset/yaml_files/__init__.py @@ -0,0 +1 @@ +"""Configuration files for fetching/validating dataset versions.""" diff --git a/modelforge/dataset/yaml_files/spice1.yaml b/modelforge/dataset/yaml_files/spice1.yaml index ac88b3f4..06b8a9da 100644 --- a/modelforge/dataset/yaml_files/spice1.yaml +++ b/modelforge/dataset/yaml_files/spice1.yaml @@ -1,6 +1,63 @@ dataset: spice1 -latest: full_dataset_v0 -latest_test: nc_1000_v0 +latest: full_dataset_v1 +latest_test: nc_1000_v1 + +full_dataset_v1: + version: 1 + doi: 10.5281/zenodo.13883667 + gz_data_file: + length: 11218036423 + md5: eb94c7d8d8bf06cd3c6f9bf13be4c364 + name: SPICE114_dataset_v1.hdf5.gz + hdf5_data_file: + md5: 4bb4e8f638b86096ee2dd83685d8d494 + name: SPICE114_dataset_v1.hdf5 + processed_data_file: + md5: null + name: SPICE114_dataset_v1_processed.npz + url: https://zenodo.org/records/13883667/files/spice_114_dataset_v1.hdf5.gz +nc_1000_v1: + version: 1 + doi: 10.5281/zenodo.13883550 + gz_data_file: + length: 15165095 + md5: dabdff24358c303f82f04c6599ac53d0 + name: SPICE114_dataset_v1_nc_1000.hdf5.gz + hdf5_data_file: + md5: 33d2d04bb14d59fc54ad538e1083ea9e + name: SPICE114_dataset_v1_nc_1000.hdf5 + processed_data_file: + md5: null + name: SPICE114_dataset_v1_nc_1000_processed.npz + url: https://zenodo.org/records/13883550/files/spice_114_dataset_v1_ntc_1000.hdf5.gz +full_dataset_v1_HCNOFClS: + version: 1 + doi: 10.5281/zenodo.13883617 + gz_data_file: + length: 10046842576 + md5: d17383d4c2cf51aec790036c9fe98c04 + name: SPICE114_dataset_v1_HCNOFClS.hdf5.gz + hdf5_data_file: + md5: aae832f44a748a4bc9932bbed6ee4c7e + name: SPICE114_dataset_v1_HCNOFClS.hdf5 + processed_data_file: + md5: null + name: SPICE114_dataset_v1_HCNOFClS_processed.npz + url: https://zenodo.org/records/13883617/files/spice_114_dataset_v1_HCNOFClS.hdf5.gz +nc_1000_v1_HCNOFClS: + version: 1 + doi: 10.5281/zenodo.13883112 + gz_data_file: + length: 14782035 + md5: a9c327a5e86dae548c93918cd4f5a821 + name: SPICE114_dataset_v1_nc_1000_HCNOFClS.hdf5.gz + hdf5_data_file: + md5: dc143bd509287f5f015c46058430a8fd + name: SPICE114_dataset_v1_nc_1000_HCNOFClS.hdf5 + processed_data_file: + md5: null + name: SPICE114_dataset_v1_nc_1000_HCNOFClS_processed.npz + url: https://zenodo.org/records/13883112/files/spice_114_dataset_v1_ntc_1000_HCNOFClS.hdf5.gz full_dataset_v0: version: 0 doi: 10.5281/zenodo.11583341 @@ -27,7 +84,7 @@ nc_1000_v0: name: SPICE114_dataset_v0_nc_1000.hdf5 processed_data_file: md5: null - name: SPICE114_dataset_nc_1000_processed.npz + name: SPICE114_dataset_v0_nc_1000_processed.npz url: https://zenodo.org/records/11607708/files/spice_114_dataset_v0_ntc_1000.hdf5.gz full_dataset_v0_HCNOFClS: version: 0 diff --git a/modelforge/dataset/yaml_files/spice1openff.yaml b/modelforge/dataset/yaml_files/spice1openff.yaml index 26b0a670..37406c7d 100644 --- a/modelforge/dataset/yaml_files/spice1openff.yaml +++ b/modelforge/dataset/yaml_files/spice1openff.yaml @@ -1,6 +1,62 @@ dataset: spice1openff -latest: full_dataset_v0 -latest_test: nc_1000_v0 +latest: full_dataset_v1 +latest_test: nc_1000_v1 +full_dataset_v1: + version: 1 + gz_data_file: + doi: 10.5281/zenodo.13883727 + length: 2545886523 + md5: e7dc6672f629aa8c539307bfabc17b61 + name: SPICE1_OpenFF_dataset_v1.hdf5.gz + hdf5_data_file: + md5: b4987ae2d362e160162308883bc2211d + name: SPICE1_OpenFF_dataset_v1.hdf5 + processed_data_file: + md5: null + name: SPICE1_OpenFF_dataset_v1_processed.npz + url: https://zenodo.org/records/13883727/files/spice_114_openff_dataset_v1.hdf5.gz +nc_1000_v1: + version: 1 + doi: 10.5281/zenodo.13883722 + gz_data_file: + length: 2510244 + md5: 4e5f1bf4f347aff5282888e7e0ef40ae + name: SPICE1_OpenFF_dataset_v1_nc_1000.hdf5.gz + hdf5_data_file: + md5: 0abe3d918195c5287ce093951fbcac50 + name: SPICE1_OpenFF_dataset_v1_nc_1000.hdf5 + processed_data_file: + md5: null + name: SPICE1_OpenFF_dataset_v1_nc_1000_processed.npz + url: https://zenodo.org/records/13883722/files/spice_114_openff_dataset_v1_ntc_1000.hdf5.gz +full_dataset_v1_HCNOFClS: + version: 1 + doi: 10.5281/zenodo.13883653 + gz_data_file: + length: 2306536068 + md5: 41e3da36f70cd459bcfc870eba4bfbdb + name: SPICE1_OpenFF_dataset_v1_HCNOFClS.hdf5.gz + hdf5_data_file: + md5: e50d98b13eb6d09ecead0ca57a56f7ec + name: SPICE1_OpenFF_dataset_v1_HCNOFClS.hdf5 + processed_data_file: + md5: null + name: SPICE1_OpenFF_dataset_v1_HCNOFClS_processed.npz + url: https://zenodo.org/records/13883653/files/spice_114_openff_dataset_v1_HCNOFClS.hdf5.gz +nc_1000_v1_HCNOFClS: + version: 1 + doi: 10.5281/zenodo.13883717 + gz_data_file: + length: 2534101 + md5: 75f595abdb5920708381e7d3d0526d70 + name: SPICE1_OpenFF_dataset_v1_nc_1000_HCNOFClS.hdf5.gz + hdf5_data_file: + md5: d49097528db06cd03f21786055300d4e + name: SPICE1_OpenFF_dataset_v1_nc_1000_HCNOFClS.hdf5 + processed_data_file: + md5: null + name: SPICE1_OpenFF_dataset_v1_nc_1000_HCNOFClS_processed.npz + url: https://zenodo.org/records/13883717/files/spice_114_openff_dataset_v1_ntc_1000_HCNOFClS.hdf5.gz full_dataset_v0: version: 0 gz_data_file: diff --git a/modelforge/dataset/yaml_files/spice2.yaml b/modelforge/dataset/yaml_files/spice2.yaml index 09f518b5..b98bd477 100644 --- a/modelforge/dataset/yaml_files/spice2.yaml +++ b/modelforge/dataset/yaml_files/spice2.yaml @@ -1,6 +1,62 @@ dataset: spice2 -latest: full_dataset_v0 -latest_test: nc_1000_v0 +latest: full_dataset_v1 +latest_test: nc_1000_v1 +full_dataset_v1: + version: 1 + doi: 10.5281/zenodo.13883915 + gz_data_file: + length: 26313365249 + md5: 46ab1dcbea52dcb92fc7e65702459946 + name: SPICE2_dataset_v1.hdf5.gz + hdf5_data_file: + md5: f8cfee945d3546af76cbea45d8dabab4 + name: SPICE2_dataset_v1.hdf5 + processed_data_file: + md5: null + name: SPICE2_dataset_v1_processed.npz + url: https://zenodo.org/records/13883915/files/spice_2_dataset_v1.hdf5.gz +nc_1000_v1: + version: 1 + doi: 10.5281/zenodo.13883816 + gz_data_file: + length: 26757127 + md5: 92959e49db0023a47dedcbaa7f316e75 + name: SPICE2_dataset_v1_nc_1000.hdf5.gz + hdf5_data_file: + md5: a35cbffbf7fea6242081ace2131630b3 + name: SPICE2_dataset_v1_nc_1000.hdf5 + processed_data_file: + md5: null + name: SPICE2_dataset_v1_nc_1000_processed.npz + url: https://zenodo.org/records/13883816/files/spice_2_dataset_v1_ntc_1000.hdf5.gz +full_dataset_v1_HCNOFClS: + version: 1 + doi: 10.5281/zenodo.13883829 + gz_data_file: + length: 26313365249 + md5: 0dea7f31bb2fbb1f7e190b49530d0c51 + name: SPICE2_dataset_v1_HCNOFClS.hdf5.gz + hdf5_data_file: + md5: 78a41b098eb0999013c10b65e9a57875 + name: SPICE2_dataset_v1_HCNOFClS.hdf5 + processed_data_file: + md5: null + name: SPICE2_dataset_v1_HCNOFClS_processed.npz + url: https://zenodo.org/records/13883829/files/spice_2_dataset_v1_HCNOFClS.hdf5.gz +nc_1000_v1_HCNOFClS: + version: 1 + doi: 10.5281/zenodo.13883771 + gz_data_file: + length: 26816135 + md5: 79309a230e8481373a0843b8cbac5f70 + name: SPICE2_dataset_v1_nc_1000_HCNOFClS.hdf5.gz + hdf5_data_file: + md5: 3246dc275f098459d963af633c1666a6 + name: SPICE2_dataset_v1_nc_1000_HCNOFClS.hdf5 + processed_data_file: + md5: null + name: SPICE2_dataset_v1_nc_1000_HCNOFClS_processed.npz + url: https://zenodo.org/records/13883771/files/spice_2_dataset_v1_ntc_1000_HCNOFClS.hdf5.gz full_dataset_v0: version: 0 doi: 10.5281/zenodo.11632270 diff --git a/modelforge/jax.py b/modelforge/jax.py new file mode 100644 index 00000000..2140af6f --- /dev/null +++ b/modelforge/jax.py @@ -0,0 +1,50 @@ +from modelforge.utils.prop import NNPInput + + +def nnpinput_flatten(nnpinput: NNPInput): + # Collect all attributes into a tuple + children = ( + nnpinput.atomic_numbers, + nnpinput.positions, + nnpinput.atomic_subsystem_indices, + nnpinput.per_system_total_charge, + nnpinput.box_vectors, + nnpinput.is_periodic, + nnpinput.pair_list, + nnpinput.per_atom_partial_charge, + ) + # No auxiliary data is needed + aux_data = None + return (children, aux_data) + + +def nnpinput_unflatten(aux_data, children): + # Reconstruct the NNPInput instance from the children + return NNPInput(*children) + + +def convert_NNPInput_to_jax(nnp_input: NNPInput): + """ + Convert the NNPInput to a JAX-compatible format. + """ + from modelforge.utils.io import import_ + + convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax + + nnp_input.atomic_numbers = convert_to_jax(nnp_input.atomic_numbers) + nnp_input.positions = convert_to_jax(nnp_input.positions) + nnp_input.atomic_subsystem_indices = convert_to_jax( + nnp_input.atomic_subsystem_indices + ) + nnp_input.per_system_total_charge = convert_to_jax( + nnp_input.per_system_total_charge + ) + nnp_input.box_vectors = convert_to_jax(nnp_input.box_vectors) + nnp_input.is_periodic = convert_to_jax(nnp_input.is_periodic) + + nnp_input.pair_list = convert_to_jax(nnp_input.pair_list) + nnp_input.per_atom_partial_charge = convert_to_jax( + nnp_input.per_atom_partial_charge + ) + + return nnp_input diff --git a/modelforge/potential/__init__.py b/modelforge/potential/__init__.py index 796dd50d..fc134aa8 100644 --- a/modelforge/potential/__init__.py +++ b/modelforge/potential/__init__.py @@ -1,25 +1,69 @@ -from .schnet import SchNet -from .physnet import PhysNet -from .painn import PaiNN -from .ani import ANI2x -from .sake import SAKE -from .spookynet import SpookyNet -from .utils import ( - CosineCutoff, - RadialBasisFunction, - AngularSymmetryFunction, +""" +This module contains the implemented neural network potentials and their parameters. +""" + +from enum import Enum + +from .potential import NeuralNetworkPotentialFactory +from .painn import PaiNNCore +from .parameters import ( + AimNet2Parameters, + ANI2xParameters, + PaiNNParameters, + PhysNetParameters, + SAKEParameters, + SchNetParameters, + TensorNetParameters, + SpookyNetParameters, ) from .processing import FromAtomToMoleculeReduction -from .models import NeuralNetworkPotentialFactory -from enum import Enum +from .representation import ( + CosineAttenuationFunction, + PhysNetRadialBasisFunction, + TensorNetRadialBasisFunction, + AniRadialBasisFunction, + SchnetRadialBasisFunction, +) +from .featurization import FeaturizeInput + +from .physnet import PhysNetCore +from .sake import SAKECore +from .schnet import SchNetCore +from .tensornet import TensorNetCore +from .aimnet2 import AimNet2Core +from .ani import ANI2xCore + + +class _Implemented_NNP_Parameters(Enum): + ANI2X_PARAMETERS = ANI2xParameters + SCHNET_PARAMETERS = SchNetParameters + TENSORNET_PARAMETERS = TensorNetParameters + PAINN_PARAMETERS = PaiNNParameters + PHYSNET_PARAMETERS = PhysNetParameters + SAKE_PARAMETERS = SAKEParameters + AIMNET2_PARAMETERS = AimNet2Parameters + + @classmethod + def get_neural_network_parameter_class(cls, neural_network_name: str): + try: + # Normalize the input and get the class directly from the Enum + name = neural_network_name.upper() + "_PARAMETERS" + return cls[name.upper()].value + except KeyError: + available_potentials = ", ".join([d.name for d in cls]) + raise ValueError( + f"Parameters for {neural_network_name} are not implemented. Available parameters: {available_potentials}" + ) class _Implemented_NNPs(Enum): - ANI2X = ANI2x - SCHNET = SchNet - PAINN = PaiNN - PHYSNET = PhysNet - SAKE = SAKE + SCHNET = SchNetCore + ANI2X = ANI2xCore + PHYSNET = PhysNetCore + TENSORNET = TensorNetCore + PAINN = PaiNNCore + SAKE = SAKECore + AIMNET2 = AimNet2Core SPOOKYNET = SpookyNet @classmethod @@ -28,9 +72,9 @@ def get_neural_network_class(cls, neural_network_name: str): # Normalize the input and get the class directly from the Enum return cls[neural_network_name.upper()].value except KeyError: - available_datasets = ", ".join([d.name for d in cls]) + available_potentials = ", ".join([d.name for d in cls]) raise ValueError( - f"Dataset {neural_network_name} is not implemented. Available datasets are: {available_datasets}" + f"Potential {neural_network_name} is not implemented. Available potentials are: {available_potentials}" ) @staticmethod diff --git a/modelforge/potential/aimnet2.py b/modelforge/potential/aimnet2.py new file mode 100644 index 00000000..3dc27c5f --- /dev/null +++ b/modelforge/potential/aimnet2.py @@ -0,0 +1,557 @@ +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from loguru import logger as log + +from modelforge.potential.utils import Dense + +from modelforge.dataset.dataset import NNPInput +from modelforge.potential.neighbors import PairlistData + + +class AimNet2Core(torch.nn.Module): + def __init__( + self, + featurization: Dict[str, Dict[str, int]], + number_of_radial_basis_functions: int, + number_of_vector_features: int, + number_of_interaction_modules: int, + activation_function_parameter: Dict[str, str], + predicted_properties: List[str], + predicted_dim: List[int], + maximum_interaction_radius: float, + ) -> None: + """ + Core architecture of the AimNet2 model for molecular property + prediction. + + Parameters + ---------- + featurization : Dict[str, Dict[str, int]] + Configuration dictionary specifying feature details for atomic + embeddings. + number_of_radial_basis_functions : int + Number of radial basis functions used in the radial symmetry + function. + number_of_interaction_modules : int + Number of interaction modules in the model, determining the depth of + message passing. + activation_function_parameter : Dict[str, str] + Configuration of activation functions used across the model. + predicted_properties : List[str] + List of properties that the model is predicting (e.g., energy, + forces). + predicted_dim : List[int] + The dimensionality of each predicted property. + maximum_interaction_radius : float + The cutoff radius for atomic interactions in the model. + """ + + super().__init__() + + log.debug("Initializing the AimNet2 architecture.") + + self.activation_function = activation_function_parameter["activation_function"] + + # Initialize representation block + self.representation_module = AIMNet2Representation( + maximum_interaction_radius, + number_of_radial_basis_functions, + featurization_config=featurization, + ) + number_of_per_atom_features = int( + featurization["atomic_number"]["number_of_per_atom_features"] + ) + + self.agh = nn.Parameter( + torch.randn( + number_of_per_atom_features, # F_atom + number_of_radial_basis_functions, # G + number_of_vector_features, # H + ) + ) + # shape(nr_of_angular_symmetry_functions,nr_of_radial_symmetry_functions,nr_of_vector_features) + + # Define interaction modules for message passing + self.interaction_modules = torch.nn.ModuleList( + [ + AIMNet2InteractionModule( + number_of_per_atom_features=number_of_per_atom_features, + number_of_radial_basis_functions=number_of_radial_basis_functions, + number_of_vector_features=number_of_vector_features, + activation_function=self.activation_function, + is_first_module=(i == 0), + ) + for i in range(number_of_interaction_modules) + ] + ) + # Define output layers to calculate per-atom predictions + self.output_layers = nn.ModuleDict() + for property, dim in zip(predicted_properties, predicted_dim): + self.output_layers[property] = nn.Sequential( + Dense( + number_of_per_atom_features, + number_of_per_atom_features, + activation_function=self.activation_function, + ), + Dense( + number_of_per_atom_features, + int(dim), + ), + ) + from modelforge.potential.processing import ChargeConservation + + self.charge_conservation = ChargeConservation() + + def compute_properties( + self, + data: NNPInput, + pairlist: PairlistData, + ) -> Dict[str, torch.Tensor]: + """ + Calculate the requested properties for a given input batch. + + Parameters + ---------- + data : NNPInput + The input data for the model. + pairlist: PairlistData + The output from the pairlist module. + Returns + ------- + Dict[str, torch.Tensor] + The calculated per-atom scalar representations and atomic subsystem + indices. + """ + + rep = self.representation_module(data, pairlist) + atomic_embedding = rep["atomic_embedding"] + r_ij, d_ij, f_ij, f_cutoff = ( + pairlist.r_ij, + pairlist.d_ij, + rep["f_ij"], + rep["f_cutoff"], + ) + # Scalar Gaussian expansion for radial terms + gs = f_ij * f_cutoff # Shape: (number_of_pairs, G) + # Unit direction vectors + u_ij = r_ij / d_ij + # Compute gv with shape (number_of_pairs, 3, G) + gv = u_ij.unsqueeze(-1) * gs.unsqueeze(1) # Broadcasting over G + + # Atomic embedding "a" Eqn. (3) + partial_charges = torch.zeros( + (atomic_embedding.shape[0], 1), device=atomic_embedding.device + ) + + # Perform message passing using interaction modules + for i, interaction in enumerate(self.interaction_modules): + + delta_a, delta_q, f = interaction( + atomic_embedding, + partial_charges, + pairlist.pair_indices, + gs, + gv, + self.agh, + ) + + # Update atomic embeddings + atomic_embedding = atomic_embedding + delta_a + + # Apply scaling factor `f` to `delta_q` + scaled_delta_q = f * delta_q + + # Update partial charges + if i == 0: + partial_charges = scaled_delta_q # Initialize charges + else: + partial_charges = partial_charges + scaled_delta_q # Incremental update + + partial_charges = self.charge_conservation( + { + "per_atom_charge": partial_charges, + "per_system_total_charge": data.per_system_total_charge.to( + dtype=torch.float32 + ), + "atomic_subsystem_indices": data.atomic_subsystem_indices.to( + dtype=torch.int64 + ), + } + )["per_atom_charge"] + + # check that none of the tensors are NaN + if torch.isnan(atomic_embedding).any(): + raise ValueError("NaN values detected in atomic embeddings.") + if torch.isnan(partial_charges).any(): + raise ValueError("NaN values detected in partial charges.") + + return { + "per_atom_scalar_representation": atomic_embedding, + "atomic_subsystem_indices": data.atomic_subsystem_indices, + "atomic_numbers": data.atomic_numbers, + } + + def forward( + self, + data: NNPInput, + pairlist_output: PairlistData, + ) -> Dict[str, torch.Tensor]: + """ + Implements the forward pass through the network. + + Parameters + ---------- + data : NNPInput + Contains input data for the batch obtained directly from the + dataset, including atomic numbers, positions, and other relevant + fields. + pairlist_output : PairListOutputs + Contains the indices for the selected pairs and their associated + distances and displacement vectors. + + Returns + ------- + Dict[str, torch.Tensor] + The calculated per-atom properties and other properties from the + forward pass. + """ + # perform the forward pass implemented in the subclass + results = self.compute_properties(data, pairlist_output) + atomic_embedding = results["per_atom_scalar_representation"] + + # Compute all specified outputs + for output_name, output_layer in self.output_layers.items(): + output = output_layer(atomic_embedding) + results[output_name] = output + + return results + + +import torch +import torch.nn as nn +from torch import Tensor +from typing import Tuple + + +class AIMNet2InteractionModule(nn.Module): + def __init__( + self, + number_of_per_atom_features: int, + number_of_radial_basis_functions: int, + number_of_vector_features: int, + activation_function: nn.Module, + is_first_module: bool = False, + ): + super().__init__() + self.is_first_module = is_first_module + self.number_of_per_atom_features = number_of_per_atom_features + self.number_of_vector_features = number_of_vector_features + self.gs_to_fatom = Dense( + number_of_radial_basis_functions, number_of_per_atom_features, bias=False + ) + + if not self.is_first_module: + self.number_of_input_features = ( + number_of_per_atom_features # radial_contributions_emb + + number_of_vector_features # vector_contributions_emb + + number_of_per_atom_features # radial_contributions_charge + + number_of_vector_features # vector_contributions_charge + ) + else: + self.number_of_input_features = ( + number_of_per_atom_features # radial_contributions_emb + + number_of_vector_features # vector_contributions_emb + ) + + # Single MLP producing combined outputs + self.mlp = nn.Sequential( + Dense( + in_features=self.number_of_input_features, + out_features=128, + activation_function=activation_function, + ), + Dense( + in_features=128, + out_features=128, + activation_function=activation_function, + ), + Dense( + in_features=128, + out_features=number_of_per_atom_features + 2, # delta_q, f, delta_a + ), + ) + + def calculate_radial_contributions( + self, + gs: Tensor, + a_j: Tensor, + number_of_atoms: int, + idx_j: Tensor, + ) -> Tensor: + """ + Compute radial contributions for each atom based on pair interactions. + + Parameters + ---------- + gs : Tensor + Radial symmetry functions with shape (number_of_pairs, G). + a_j : Tensor + Atomic features for each pair with shape (number_of_pairs, F_atom) or (number_of_pairs, 1). + number_of_atoms : int + Total number of atoms in the system. + idx_j : Tensor + Indices mapping each pair to an atom, with shape (number_of_pairs,). + + Returns + ------- + Tensor + Radial contributions aggregated per atom, with shape (number_of_atoms, F_atom). + """ + # Map gs to shape (number_of_pairs, F_atom) + mapped_gs = self.gs_to_fatom(gs) # Shape: (number_of_pairs, F_atom) + + # Compute avf_s using element-wise multiplication + avf_s = a_j * mapped_gs # Shape: (number_of_pairs, F_atom) + + # Initialize tensor to accumulate radial contributions + radial_contributions = torch.zeros( + (number_of_atoms, avf_s.shape[-1]), + device=avf_s.device, + dtype=avf_s.dtype, + ) + # Aggregate per atom + radial_contributions.index_add_(0, idx_j, avf_s) + + return radial_contributions + + def calculate_vector_contributions( + self, + gv: Tensor, + a_j: Tensor, + idx_j: Tensor, + agh: Tensor, + number_of_atoms: int, + device: torch.device, + ) -> Tensor: + """ + Compute vector (angular) contributions for each atom based on pair interactions. + + Parameters + ---------- + gv : Tensor + Vector symmetry functions with shape (number_of_pairs, 3, G). + a_j : Tensor + Atomic features for each pair with shape (number_of_pairs, F_atom). + idx_j : Tensor + Indices mapping each pair to an atom, with shape (number_of_pairs,). + agh : Tensor + Transformation tensor with shape (F_atom, G, H). + number_of_atoms : int + Total number of atoms in the system. + device : torch.device + The device to perform computations on. + + Returns + ------- + Tensor + Vector contributions aggregated per atom, with shape (number_of_atoms, H). + """ + # Compute per-pair vector contributions + # avf_v: (number_of_pairs, H, 3) + avf_v = torch.einsum("pa, pdg, agh -> phd", a_j, gv, agh) + + # Initialize tensor to accumulate vector contributions per atom + avf_v_sum = torch.zeros( + (number_of_atoms, avf_v.shape[1], avf_v.shape[2]), + device=device, + dtype=avf_v.dtype, + ) + # Aggregate per atom by summing the vectors + avf_v_sum.index_add_(0, idx_j, avf_v) # Shape: (number_of_atoms, H, 3) + + # Compute the norm over the last dimension (vector components) + vector_contributions = torch.norm( + avf_v_sum, dim=-1 + ) # Shape: (number_of_atoms, H) + + return vector_contributions + + def calculate_contributions( + self, + atomic_embedding: Tensor, + pair_indices: Tensor, + gs: Tensor, + gv: Tensor, + agh: Tensor, + calculate_vector_contributions: bool, + ) -> Tuple[Tensor, Tensor]: + idx_j = pair_indices[1] + a_j = atomic_embedding[idx_j] # Shape: (number_of_pairs, F_atom) + + radial_contributions = self.calculate_radial_contributions( + gs, + a_j, + atomic_embedding.shape[0], + idx_j, + ) + + if calculate_vector_contributions: + vector_contributions = self.calculate_vector_contributions( + gv, + a_j, + idx_j, + agh, + number_of_atoms=atomic_embedding.shape[0], + device=atomic_embedding.device, + ) + else: + # Return zeros with shape (number_of_atoms, number_of_vector_features) + vector_contributions = torch.zeros( + (atomic_embedding.shape[0], self.number_of_vector_features), + device=atomic_embedding.device, + ) + + return radial_contributions, vector_contributions + + def forward( + self, + atomic_embedding: Tensor, + partial_charges: Tensor, + pair_indices: Tensor, + gs: Tensor, + gv: Tensor, + agh: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor]: + + # Calculate contributions from embeddings + radial_contributions_emb, vector_contributions_emb = ( + self.calculate_contributions( + atomic_embedding, + pair_indices, + gs, + gv, + agh, + calculate_vector_contributions=True, + ) + ) + + if not self.is_first_module: + # Calculate contributions from charges + radial_contributions_charge, vector_contributions_charge = ( + self.calculate_contributions( + partial_charges, + pair_indices, + gs, + gv, + agh, + calculate_vector_contributions=False, + ) + ) + # Combine messages + combined_message = torch.cat( + [ + radial_contributions_emb, # (N, F_atom) + vector_contributions_emb, # (N, H) + radial_contributions_charge, # (N, 1) + vector_contributions_charge, # (N, H) + ], + dim=1, + ) + else: + combined_message = torch.cat( + [ + radial_contributions_emb, # (N, F_atom) + vector_contributions_emb, # (N, H) + ], + dim=1, + ) + + # Pass combined message through single MLP + out = self.mlp(combined_message) + + # Split the output tensor into delta_q, f, and delta_a + delta_q, f, delta_a = torch.split( + out, [1, 1, self.number_of_per_atom_features], dim=1 + ) + + return delta_a, delta_q, f + + +class AIMNet2Representation(nn.Module): + def __init__( + self, + radial_cutoff: float, + number_of_radial_basis_functions: int, + featurization_config: Dict[str, Dict[str, int]], + ): + """ + Initialize the AIMNet2 representation layer. + + Parameters + ---------- + radial_cutoff : float + The cutoff distance for the radial symmetry function in nanometer. + number_of_radial_basis_functions : int + Number of radial basis functions to use. + featurization_config : Dict[str, Union[List[str], int]] + Configuration for the featurization process. + """ + super().__init__() + + self.radial_symmetry_function_module = self._setup_radial_symmetry_functions( + radial_cutoff, number_of_radial_basis_functions + ) + # Initialize cutoff module + from modelforge.potential import CosineAttenuationFunction + from modelforge.potential.featurization import FeaturizeInput + + self.featurize_input = FeaturizeInput(featurization_config) + self.cutoff_module = CosineAttenuationFunction(radial_cutoff) + + def _setup_radial_symmetry_functions( + self, radial_cutoff: float, number_of_radial_basis_functions: int + ): + from modelforge.potential import SchnetRadialBasisFunction + + radial_symmetry_function = SchnetRadialBasisFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=radial_cutoff, + dtype=torch.float32, + ) + return radial_symmetry_function + + def forward( + self, + data: NNPInput, + pairlist_output: PairlistData, + ) -> Dict[str, torch.Tensor]: + """ + Generate the radial symmetry representation of the pairwise distances. + + Parameters + ---------- + data : NNPInput + The input data including atomic positions and numbers. + pairlist_output : PairlistData + Pairwise distances between atoms and pair indices. + + Returns + ------- + Dict[str, torch.Tensor] + The radial basis functions and atomic embeddings. + """ + + # Convert distances to radial basis functions + f_ij = self.radial_symmetry_function_module(pairlist_output.d_ij) + # Apply cutoff function to radial basis + f_cutoff = self.cutoff_module(pairlist_output.d_ij) + + return { + "f_ij": f_ij, + "f_cutoff": f_cutoff, + "atomic_embedding": self.featurize_input( + data + ), # add per-atom properties and embedding + } diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index 4a1f0748..2d4d0ebe 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -1,30 +1,55 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Tuple -from .models import InputPreparation, BaseNetwork, CoreNetwork +""" +This module contains the classes for the ANI2x neural network potential. +The ANI2x architecture is used for neural network potentials that compute atomic +energies based on Atomic Environment Vectors (AEVs). It supports multiple +species and interaction types, and allows prediction of properties like energy +using a neural network model. +""" + +from typing import Dict, Tuple, List import torch from loguru import logger as log from torch import nn -from modelforge.utils.prop import SpeciesAEV +from modelforge.utils.prop import SpeciesAEV, NNPInput + +from modelforge.potential.neighbors import PairlistData + + +def init_params(m): + if isinstance(m, torch.nn.Linear): + torch.nn.init.kaiming_normal_(m.weight, a=1.0) + torch.nn.init.zeros_(m.bias) -if TYPE_CHECKING: - from modelforge.dataset.dataset import NNPInput - from .models import PairListOutputs +def triu_index(number_of_atoms: int) -> torch.Tensor: + """ + Generate a tensor representing the upper triangular indices for species + pairs. This is used for computing angular symmetry features, where pairwise + combinations of species need to be considered. + + Parameters + ---------- + num_species : int + The number of species in the system. -def triu_index(num_species: int) -> torch.Tensor: - species1, species2 = torch.triu_indices(num_species, num_species).unbind(0) + Returns + ------- + torch.Tensor + A tensor containing the pair indices. + """ + species1, species2 = torch.triu_indices(number_of_atoms, number_of_atoms).unbind(0) pair_index = torch.arange(species1.shape[0], dtype=torch.long) - ret = torch.zeros(num_species, num_species, dtype=torch.long) + ret = torch.zeros(number_of_atoms, number_of_atoms, dtype=torch.long) ret[species1, species2] = pair_index ret[species2, species1] = pair_index return ret -from modelforge.potential.utils import NeuralNetworkData - +# A map from atomic number to an internal index used for species-specific +# computations. ATOMIC_NUMBER_TO_INDEX_MAP = { 1: 0, # H 6: 1, # C @@ -36,109 +61,163 @@ def triu_index(num_species: int) -> torch.Tensor: } -@dataclass -class AniNeuralNetworkData(NeuralNetworkData): +class ANIRepresentation(nn.Module): """ - A dataclass to structure the inputs for ANI neural network potentials, designed to - facilitate the efficient representation of atomic systems for energy computation and - property prediction. + Compute the Atomic Environment Vectors (AEVs) for the ANI architecture. AEVs + are representations of the local atomic environment used as input to the + neural network. - Attributes + Parameters ---------- - pair_indices : torch.Tensor - A 2D tensor indicating the indices of atom pairs. Shape: [2, num_pairs]. - d_ij : torch.Tensor - A 1D tensor containing distances between each pair of atoms. Shape: [num_pairs, 1]. - r_ij : torch.Tensor - A 2D tensor representing displacement vectors between atom pairs. Shape: [num_pairs, 3]. - number_of_atoms : int - An integer indicating the number of atoms in the batch. - positions : torch.Tensor - A 2D tensor representing the XYZ coordinates of each atom. Shape: [num_atoms, 3]. - atom_index : torch.Tensor - A 1D tensor containing atomic numbers for each atom in the system(s). Shape: [num_atoms]. - atomic_subsystem_indices : torch.Tensor - A 1D tensor mapping each atom to its respective subsystem or molecule. Shape: [num_atoms]. - total_charge : torch.Tensor - An tensor with the total charge of each system or molecule. Shape: [num_systems]. - atomic_numbers : torch.Tensor - A 1D tensor containing the atomic numbers for atoms, used for identifying the atom types within the model. Shape: [num_atoms]. - - Notes - ----- - The `AniNeuralNetworkInput` dataclass encapsulates essential inputs required by the - ANI neural network model to predict system energies and properties accurately. It - includes atomic positions, types, and connectivity information, crucial for representing - atomistic systems in detail. - - Examples - -------- - >>> ani_input = AniNeuralNetworkData( - ... pair_indices=torch.tensor([[0, 1], [0, 2], [1, 2]]).T, # Transpose for correct shape - ... d_ij=torch.tensor([[1.0], [1.0], [1.0]]), # Distances between pairs - ... r_ij=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), # Displacement vectors - ... number_of_atoms=4, # Total number of atoms - ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]), - ... atom_index=torch.tensor([1, 6, 6, 8]), # Atomic numbers for H, C, C, O - ... atomic_subsystem_indices=torch.tensor([0, 0, 0, 0]), # All atoms belong to the same molecule - ... total_charge=torch.tensor([0.0]), # Assuming the molecule is neutral - ... atomic_numbers=torch.tensor([1, 6, 6, 8]) # Repeated for completeness - ... ) + radial_max_distance : float + The maximum distance for radial symmetry functions in nanometer. + radial_min_distance : float + The minimum distance for radial symmetry functions in nanometer. + number_of_radial_basis_functions : int + The number of radial basis functions. + maximum_interaction_radius_for_angular_features : float + The maximum interaction radius for angular features in nanometer. + minimum_interaction_radius_for_angular_features : float + The minimum interaction radius for angular features in nanometer. + angular_dist_divisions : int + The number of angular distance divisions. + angle_sections : int + The number of angle sections. + nr_of_supported_elements : int, optional + The number of supported elements, by default 7. """ - atom_index: torch.Tensor - - -from openff.units import unit - - -class ANIRepresentation(nn.Module): - # calculate the atomic environment vectors - # used for the ANI architecture of NNPs - def __init__( self, - radial_max_distance: unit.Quantity, - radial_min_distanc: unit.Quantity, + radial_max_distance: float, + radial_min_distance: float, number_of_radial_basis_functions: int, - angular_max_distance: unit.Quantity, - angular_min_distance: unit.Quantity, + maximum_interaction_radius_for_angular_features: float, + minimum_interaction_radius_for_angular_features: float, angular_dist_divisions: int, angle_sections: int, nr_of_supported_elements: int = 7, ): - # radial symmetry functions - super().__init__() - from modelforge.potential.utils import CosineCutoff + from modelforge.potential import CosineAttenuationFunction - self.angular_max_distance = angular_max_distance + self.maximum_interaction_radius_for_angular_features = ( + maximum_interaction_radius_for_angular_features + ) self.nr_of_supported_elements = nr_of_supported_elements - self.cutoff_module = CosineCutoff(radial_max_distance) + self.cutoff_module = CosineAttenuationFunction(radial_max_distance) + # Initialize radial and angular symmetry functions self.radial_symmetry_functions = self._setup_radial_symmetry_functions( - radial_max_distance, radial_min_distanc, number_of_radial_basis_functions + radial_max_distance, radial_min_distance, number_of_radial_basis_functions ) self.angular_symmetry_functions = self._setup_angular_symmetry_functions( - angular_max_distance, - angular_min_distance, + maximum_interaction_radius_for_angular_features, + minimum_interaction_radius_for_angular_features, angular_dist_divisions, angle_sections, ) - # generate indices - from modelforge.potential.utils import triple_by_molecule - - self.triple_by_molecule = triple_by_molecule + # Generate indices for species pairs self.register_buffer("triu_index", triu_index(self.nr_of_supported_elements)) + @staticmethod + def _cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor: + """ + Compute the cumulative sum from zero, used for sorting indices. + """ + cumsum = torch.zeros_like(input_) + torch.cumsum(input_[:-1], dim=0, out=cumsum[1:]) + return cumsum + + @staticmethod + def triple_by_molecule( + atom_pairs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Convert pairwise indices to central-others format for angular feature + computation. This method rearranges pairwise atomic indices for angular + symmetry functions. + + NOTE: this function is adopted from torchani library: + https://github.com/aiqm/torchani/blob/17204c6dccf6210753bc8c0ca4c92278b60719c9/torchani/aev.py + distributed under the MIT license. + + . + Parameters + ---------- + atom_pairs : torch.Tensor + A tensor of atom pair indices. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + Central atom indices, local pair indices, and sign of the pairs. + """ + + # convert representation from pair to central-others + ai1 = atom_pairs.view(-1) + + # Note, torch.sort doesn't guarantee stable sort by default. This means + # that the order of rev_indices is not guaranteed when there are "ties" + # (i.e., identical values in the input tensor). Stable sort is more + # expensive and ultimately unnecessary, so we will not use it here, but + # it does mean that vector-wise comparison of the outputs of this + # function may be inconsistent for the same input, and thus tests must + # be designed accordingly. + + sorted_ai1, rev_indices = ai1.sort() + + # sort and compute unique key + uniqued_central_atom_index, counts = torch.unique_consecutive( + sorted_ai1, return_inverse=False, return_counts=True + ) + + # compute central_atom_index + pair_sizes = torch.div(counts * (counts - 1), 2, rounding_mode="trunc") + pair_indices = torch.repeat_interleave(pair_sizes) + central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices) + + # do local combinations within unique key, assuming sorted + m = counts.max().item() if counts.numel() > 0 else 0 + n = pair_sizes.shape[0] + intra_pair_indices = ( + torch.tril_indices(m, m, -1, device=ai1.device) + .unsqueeze(1) + .expand(-1, n, -1) + ) + mask = ( + torch.arange(intra_pair_indices.shape[2], device=ai1.device) + < pair_sizes.unsqueeze(1) + ).flatten() + sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] + sorted_local_index12 += ANIRepresentation._cumsum_from_zero( + counts + ).index_select(0, pair_indices) + + # unsort result from last part + local_index12 = rev_indices[sorted_local_index12] + + # compute mapping between representation of central-other to pair + n = atom_pairs.shape[1] + sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1 + return central_atom_index, local_index12 % n, sign12 + def _setup_radial_symmetry_functions( self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, + max_distance: float, + min_distance: float, number_of_radial_basis_functions: int, ): - from .utils import AniRadialBasisFunction + """ + Initialize the radial symmetry function block. + Parameters + ---------- + max_distance : float + min_distance: float + number_of_radial_basis_functions : int + """ + from .representation import AniRadialBasisFunction radial_symmetry_function = AniRadialBasisFunction( number_of_radial_basis_functions, @@ -150,12 +229,12 @@ def _setup_radial_symmetry_functions( def _setup_angular_symmetry_functions( self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, - angular_dist_divisions, - angle_sections, + max_distance: float, + min_distance: float, + angular_dist_divisions: int, + angle_sections: int, ): - from .utils import AngularSymmetryFunction + from .representation import AngularSymmetryFunction # set up modelforge angular features return AngularSymmetryFunction( @@ -166,28 +245,50 @@ def _setup_angular_symmetry_functions( dtype=torch.float32, ) - def forward(self, data: AniNeuralNetworkData) -> SpeciesAEV: - # calculate the atomic environment vectors - # used for the ANI architecture of NNPs + def forward( + self, + data: NNPInput, + pairlist_output: PairlistData, + atom_index: torch.Tensor, + ) -> SpeciesAEV: + """ + Forward pass to compute Atomic Environment Vectors (AEVs). + + Parameters + ---------- + data : NNPInput + The input data for the ANI model. + pairlist_output : PairlistData + Pairwise distances and displacement vectors. + atom_index : torch.Tensor + Indices of atomic species. + + Returns + ------- + SpeciesAEV + The computed atomic environment vectors (AEVs) for each species. + """ # ----------------- Radial symmetry vector ---------------- # # compute radial aev - - radial_feature_vector = self.radial_symmetry_functions(data.d_ij) - # cutoff - rcut_ij = self.cutoff_module(data.d_ij) + radial_feature_vector = self.radial_symmetry_functions(pairlist_output.d_ij) + # Apply cutoff to radial features + rcut_ij = self.cutoff_module(pairlist_output.d_ij) radial_feature_vector = radial_feature_vector * rcut_ij - # process output to prepare for agular symmetry vector + # Process output to prepare for angular symmetry vector postprocessed_radial_aev_and_additional_data = self._postprocess_radial_aev( - radial_feature_vector, data=data + radial_feature_vector, + data=data, + atom_index=atom_index, + pairlist_output=pairlist_output, ) processed_radial_feature_vector = postprocessed_radial_aev_and_additional_data[ "radial_aev" ] # ----------------- Angular symmetry vector ---------------- # - # preprocess + # Compute angular AEV angular_data = self._preprocess_angular_aev( postprocessed_radial_aev_and_additional_data ) @@ -200,17 +301,33 @@ def forward(self, data: AniNeuralNetworkData) -> SpeciesAEV: processed_angular_feature_vector = self._postprocess_angular_aev( data, angular_data ) + # Concatenate radial and angular features aevs = torch.cat( [processed_radial_feature_vector, processed_angular_feature_vector], dim=-1 ) - return SpeciesAEV(data.atom_index, aevs) + return SpeciesAEV(atom_index, aevs) def _postprocess_angular_aev( - self, data: AniNeuralNetworkData, angular_data: Dict[str, torch.Tensor] + self, + data: NNPInput, + angular_data: Dict[str, torch.Tensor], ): - # postprocess the angular aev - # used for the ANI architecture of NNPs + """ + Postprocess the angular AEVs. + + Parameters + ---------- + data : NNPInput + The input data. + angular_data : Dict[str, torch.Tensor] + The angular data including species and displacement vectors. + + Returns + ------- + torch.Tensor + The processed angular AEVs. + """ angular_sublength = self.angular_symmetry_functions.angular_sublength angular_length = ( (self.nr_of_supported_elements * (self.nr_of_supported_elements + 1)) @@ -220,13 +337,15 @@ def _postprocess_angular_aev( num_species_pairs = angular_length // angular_sublength - number_of_atoms = data.number_of_atoms + number_of_atoms = data.atomic_numbers.shape[0] + # compute angular aev central_atom_index = angular_data["central_atom_index"] angular_species12 = angular_data["angular_species12"] angular_r_ij = angular_data["angular_r_ij"] angular_terms_ = angular_data["angular_feature_vector"] + # Initialize tensor to store angular AEVs angular_aev = angular_terms_.new_zeros( (number_of_atoms * num_species_pairs, angular_sublength) @@ -243,25 +362,55 @@ def _postprocess_angular_aev( def _postprocess_radial_aev( self, radial_feature_vector: torch.Tensor, - data: AniNeuralNetworkData, - ) -> Dict[str, torch.tensor]: - radial_feature_vector = radial_feature_vector.squeeze(1) - number_of_atoms = data.number_of_atoms + data: NNPInput, + atom_index: torch.Tensor, + pairlist_output: PairlistData, + ) -> Dict[str, torch.Tensor]: + """ + Postprocess the radial AEVs. + + Parameters + ---------- + radial_feature_vector : torch.Tensor + The radial feature vectors. + data : NNPInput + The input data for the ANI model. + + Returns + ------- + Dict[str, torch.Tensor] + A dictionary containing the radial AEVs and additional data. + """ + radial_feature_vector = radial_feature_vector.squeeze( + 1 + ) # Shape [num_pairs, radial_sublength] + number_of_atoms = data.atomic_numbers.shape[0] radial_sublength = ( self.radial_symmetry_functions.number_of_radial_basis_functions ) radial_length = radial_sublength * self.nr_of_supported_elements + # Initialize tensor to store radial AEVs radial_aev = radial_feature_vector.new_zeros( ( number_of_atoms * self.nr_of_supported_elements, radial_sublength, ) - ) - atom_index12 = data.pair_indices - species = data.atom_index - species12 = species[atom_index12] + ) # Shape [num_atoms * nr_of_supported_elements, radial_sublength] + + atom_index12 = ( + pairlist_output.pair_indices + ) # Shape [2, num_pairs] # this is the pair list of the atoms (e.g. C=6) + species = atom_index + species12 = species[ + atom_index12 + ] # Shape [2, num_pairs], this is the pair index but now with optimzied indexing + # What are we doing here? we generate an atomic environment vector with + # fixed dimensinos (nr_of_supported_elements, 16 (represents number of + # radial symmetry functions)) for each **element** per atom (in a pair) + + # this is a magic indexing function that works index12 = atom_index12 * self.nr_of_supported_elements + species12.flip(0) radial_aev.index_add_(0, index12[0], radial_feature_vector) radial_aev.index_add_(0, index12[1], radial_feature_vector) @@ -269,25 +418,34 @@ def _postprocess_radial_aev( radial_aev = radial_aev.reshape(number_of_atoms, radial_length) # compute new neighbors with radial_cutoff - distances = data.d_ij.T.flatten() + distances = pairlist_output.d_ij.T.flatten() even_closer_indices = ( - (distances <= self.angular_max_distance.to(unit.nanometer).m) + (distances <= self.maximum_interaction_radius_for_angular_features) .nonzero() .flatten() ) - r_ij = data.r_ij - atom_index12 = atom_index12.index_select(1, even_closer_indices) - species12 = species12.index_select(1, even_closer_indices) - r_ij_small = r_ij.index_select(0, even_closer_indices) return { "radial_aev": radial_aev, - "atom_index12": atom_index12, - "species12": species12, - "r_ij": r_ij_small, + "atom_index12": atom_index12.index_select(1, even_closer_indices), + "species12": species12.index_select(1, even_closer_indices), + "r_ij": pairlist_output.r_ij.index_select(0, even_closer_indices), } def _preprocess_angular_aev(self, data: Dict[str, torch.Tensor]): + """ + Preprocess the angular AEVs. + + Parameters + ---------- + data : Dict[str, torch.Tensor] + The data dictionary containing radial AEVs and additional data. + + Returns + ------- + Dict[str, torch.Tensor] + A dictionary containing the preprocessed angular AEV data. + """ atom_index12 = data["atom_index12"] species12 = data["species12"] r_ij = data["r_ij"] @@ -309,311 +467,425 @@ def _preprocess_angular_aev(self, data: Dict[str, torch.Tensor]): } +class MultiOutputHeadNetwork(nn.Module): + + def __init__(self, shared_layers: nn.Sequential, output_dims: int): + """ + A neural network module with multiple output heads for property prediction. + + This network shares a common set of layers and then splits into multiple + heads, each of which predicts a different output property. + + Parameters + ---------- + shared_layers : nn.Sequential + The shared layers before branching into the output heads. + output_dims : int + The number of output properties (dimensions) to predict. + """ + super().__init__() + self.shared_layers = shared_layers + # The input dimension is the output dimension of the last shared layer + input_dim = shared_layers[ + -2 + ].out_features # Get the output dim from the last shared layer + + # Create a list of output heads, one for each predicted property + self.output_heads = nn.ModuleList( + [nn.Linear(input_dim, 1) for _ in range(output_dims)] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the multi-output head network. + + The input is processed by the shared layers, and each output head generates + a prediction for one property. + + Parameters + ---------- + x : torch.Tensor + The input tensor. + + Returns + ------- + torch.Tensor + A concatenated tensor of predictions from all output heads. + """ + # Pass the input through the shared layers + x = self.shared_layers(x) + # Get the output from each head and concatenate along the last dimension + outputs = [head(x) for head in self.output_heads] + return torch.cat(outputs, dim=1) + + class ANIInteraction(nn.Module): - def __init__(self, aev_dim: int): + + def __init__( + self, + *, + aev_dim: int, + activation_function: torch.nn.Module, + predicted_properties: List[str], + predicted_dim: List[int], + ): + """ + Atomic neural network interaction module for ANI. + + This module applies a neural network to the Atomic Environment Vectors + (AEVs) to compute atomic properties like energy. + + Parameters + ---------- + aev_dim : int + The dimensionality of the AEVs. + activation_function : torch.nn.Module + The activation function to use in the neural network layers. + predicted_properties : List[str] + The names of the properties that the network will predict. + predicted_dim : List[int] + The dimensions of each predicted property. + """ super().__init__() + output_dim = int(sum(predicted_dim)) + self.predicted_properties = predicted_properties # define atomic neural network - atomic_neural_networks = self.intialize_atomic_neural_network(aev_dim) - H_network = atomic_neural_networks["H"] - C_network = atomic_neural_networks["C"] - O_network = atomic_neural_networks["O"] - N_network = atomic_neural_networks["N"] - S_network = atomic_neural_networks["S"] - F_network = atomic_neural_networks["F"] - Cl_network = atomic_neural_networks["Cl"] + atomic_neural_networks = self.intialize_atomic_neural_network( + aev_dim, activation_function, output_dim + ) + # Initialize atomic neural networks for each element in the supported + # species self.atomic_networks = nn.ModuleList( [ - H_network, - C_network, - O_network, - N_network, - S_network, - F_network, - Cl_network, + atomic_neural_networks[element] + for element in ["H", "C", "O", "N", "S", "F", "Cl"] ] ) - def intialize_atomic_neural_network(self, aev_dim: int) -> Dict[str, nn.Module]: - H_network = torch.nn.Sequential( - torch.nn.Linear(aev_dim, 256), - torch.nn.CELU(0.1), - torch.nn.Linear(256, 192), - torch.nn.CELU(0.1), - torch.nn.Linear(192, 160), - torch.nn.CELU(0.1), - torch.nn.Linear(160, 1), - ) + def intialize_atomic_neural_network( + self, + aev_dim: int, + activation_function: torch.nn.Module, + output_dim: int, + ) -> Dict[str, nn.Module]: + """ + Initialize the atomic neural networks for each chemical element. - C_network = torch.nn.Sequential( - torch.nn.Linear(aev_dim, 224), - torch.nn.CELU(0.1), - torch.nn.Linear(224, 192), - torch.nn.CELU(0.1), - torch.nn.Linear(192, 160), - torch.nn.CELU(0.1), - torch.nn.Linear(160, 1), - ) + Each element gets a separate neural network to predict properties based + on the AEVs. - N_network = torch.nn.Sequential( - torch.nn.Linear(aev_dim, 192), - torch.nn.CELU(0.1), - torch.nn.Linear(192, 160), - torch.nn.CELU(0.1), - torch.nn.Linear(160, 128), - torch.nn.CELU(0.1), - torch.nn.Linear(128, 1), - ) + Parameters + ---------- + aev_dim : int + The dimensionality of the AEVs. + activation_function : torch.nn.Module + The activation function to use. + output_dim : int + The output dimensionality for each neural network (sum of all + predicted properties). - O_network = torch.nn.Sequential( - torch.nn.Linear(aev_dim, 192), - torch.nn.CELU(0.1), - torch.nn.Linear(192, 160), - torch.nn.CELU(0.1), - torch.nn.Linear(160, 128), - torch.nn.CELU(0.1), - torch.nn.Linear(128, 1), - ) + Returns + ------- + Dict[str, nn.Module] + A dictionary mapping element symbols to their corresponding neural + networks. + """ - S_network = torch.nn.Sequential( - torch.nn.Linear(aev_dim, 160), - torch.nn.CELU(0.1), - torch.nn.Linear(160, 128), - torch.nn.CELU(0.1), - torch.nn.Linear(128, 96), - torch.nn.CELU(0.1), - torch.nn.Linear(96, 1), - ) + def create_network(layers: List[int]) -> nn.Module: + """ + Create a sequential neural network with the specified number of + layers. + + Each layer consists of a linear transformation followed by an + activation function. + + Parameters + ---------- + layers : List[int] + A list where each element is the number of units in the + corresponding layer. + + Returns + ------- + nn.Sequential + A sequential neural network with the specified layers. + """ + shared_network_layers = [] + input_dim = aev_dim + for units in layers: + shared_network_layers.append(nn.Linear(input_dim, units)) + shared_network_layers.append(activation_function) + input_dim = units + + # Return a MultiOutputHeadNetwork with shared layers and specified + # output dimensions + shared_layers = nn.Sequential(*shared_network_layers) + return MultiOutputHeadNetwork(shared_layers, output_dims=output_dim) + + # Define layer configurations for different elements + return { + element: create_network(layers) + for element, layers in { + "H": [256, 192, 160], + "C": [224, 192, 160], + "N": [192, 160, 128], + "O": [192, 160, 128], + "S": [160, 128, 96], + "F": [160, 128, 96], + "Cl": [160, 128, 96], + }.items() + } - F_network = torch.nn.Sequential( - torch.nn.Linear(aev_dim, 160), - torch.nn.CELU(0.1), - torch.nn.Linear(160, 128), - torch.nn.CELU(0.1), - torch.nn.Linear(128, 96), - torch.nn.CELU(0.1), - torch.nn.Linear(96, 1), - ) + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """ + Forward pass to compute atomic properties from AEVs. - Cl_network = torch.nn.Sequential( - torch.nn.Linear(aev_dim, 160), - torch.nn.CELU(0.1), - torch.nn.Linear(160, 128), - torch.nn.CELU(0.1), - torch.nn.Linear(128, 96), - torch.nn.CELU(0.1), - torch.nn.Linear(96, 1), - ) + For each species, the corresponding atomic neural network is used to + predict properties. - return { - "H": H_network, - "C": C_network, - "N": N_network, - "O": O_network, - "S": S_network, - "F": F_network, - "Cl": Cl_network, - } + Parameters + ---------- + input : Tuple[torch.Tensor, torch.Tensor] + A tuple containing the species tensor and the AEV tensor. - def forward(self, input: Tuple[torch.Tensor, torch.Tensor]): + Returns + ------- + torch.Tensor + The computed atomic properties for each atom. + """ species, aev = input - output = aev.new_zeros(species.shape) + per_atom_property = torch.zeros( + (species.shape[0], len(self.predicted_properties)), + dtype=aev.dtype, + device=aev.device, + ) for i, model in enumerate(self.atomic_networks): - mask = torch.eq(species, i) - midx = mask.nonzero().flatten() - if midx.shape[0] > 0: - input_ = aev.index_select(0, midx) - output[midx] = model(input_).flatten() + # create a mask to select the atoms of the current species (i) + mask = species == i + per_element_index = mask.nonzero().flatten() + # if the species is present in the batch, run it through the network + if per_element_index.shape[0] > 0: + input_ = aev.index_select(0, per_element_index) + per_element_predction = model(input_) + # Accumulate predictions in per_atom_property + per_atom_property.index_add_( + 0, + per_element_index, + per_element_predction, + ) + + return per_atom_property + - return output.view_as(species) +from typing import List -class ANI2xCore(CoreNetwork): +class ANI2xCore(torch.nn.Module): + def __init__( self, - radial_max_distance: unit.Quantity = 5.1 * unit.angstrom, - radial_min_distanc: unit.Quantity = 0.8 * unit.angstrom, - number_of_radial_basis_functions: int = 16, - angular_max_distance: unit.Quantity = 3.5 * unit.angstrom, - angular_min_distance: unit.Quantity = 0.8 * unit.angstrom, - angular_dist_divisions: int = 8, - angle_sections: int = 4, + *, + maximum_interaction_radius: float, + minimum_interaction_radius: float, + number_of_radial_basis_functions: int, + maximum_interaction_radius_for_angular_features: float, + minimum_interaction_radius_for_angular_features: float, + activation_function_parameter: Dict[str, str], + angular_dist_divisions: int, + predicted_properties: List[str], + predicted_dim: List[int], + angle_sections: int, + potential_seed: int = -1, ) -> None: """ - ANI2x Neural Network Model. + The main core module for the ANI2x architecture. + + ANI2x computes atomic properties (like energy) based on Atomic + Environment Vectors (AEVs), with support for multiple atomic species. + Parameters ---------- - radial_max_distance : Union[unit.Quantity, str] - The maximum radial distance for the radial basis functions. - radial_min_distance : Union[unit.Quantity, str] - The minimum radial distance for the radial basis functions. + maximum_interaction_radius : float + The maximum interaction radius for radial symmetry functions. + minimum_interaction_radius : float + The minimum interaction radius for radial symmetry functions. number_of_radial_basis_functions : int - The number of radial basis functions to use. - angular_max_distance : Union[unit.Quantity, str] - The maximum angular distance for the angular basis functions. - angular_min_distance : Union[unit.Quantity, str] - The minimum angular distance for the angular basis functions. + The number of radial basis functions. + maximum_interaction_radius_for_angular_features : float + The maximum interaction radius for angular symmetry functions. + minimum_interaction_radius_for_angular_features : float + The minimum interaction radius for angular symmetry functions. + activation_function_parameter : Dict[str, str] + A dictionary specifying the activation function to use. angular_dist_divisions : int - The number of divisions for the angular distance. + The number of angular distance divisions. + predicted_properties : List[str] + A list of property names that the model will predict. + predicted_dim : List[int] + A list of dimensions for each predicted property. angle_sections : int - The number of angle sections to use. - processing_operation : List[Dict[str, str]] - A list of processing operations to apply to the input data. - readout_operation : List[Dict[str, str]] - A list of readout operations to apply to the output data. - dataset_statistic : Optional[Dict[str, float]], optional - Optional dataset statistics to use for normalization, by default None. - """ - # number of elements in ANI2x - self.num_species = 7 - - log.debug("Initializing ANI model.") + The number of angular sections for the angular symmetry functions. + potential_seed : int, optional + A seed for random number generation, by default -1. + """ + + from modelforge.utils.misc import seed_random_number + + if potential_seed != -1: + seed_random_number(potential_seed) + super().__init__() - # Initialize representation block + self.num_species = 7 # Number of elements supported by ANI2x + self.predicted_dim = predicted_dim + + self.activation_function = activation_function_parameter["activation_function"] + + log.debug("Initializing the ANI2x architecture.") + self.predicted_properties = predicted_properties + + # Initialize the representation block (AEVs) self.ani_representation_module = ANIRepresentation( - radial_max_distance, - radial_min_distanc, + maximum_interaction_radius, + minimum_interaction_radius, number_of_radial_basis_functions, - angular_max_distance, - angular_min_distance, + maximum_interaction_radius_for_angular_features, + minimum_interaction_radius_for_angular_features, angular_dist_divisions, angle_sections, ) - # The length of radial aev - self.radial_length = self.num_species * number_of_radial_basis_functions - # The length of angular aev - self.angular_length = ( + # Calculate the dimensions of the radial and angular AEVs + radial_length = self.num_species * number_of_radial_basis_functions + angular_length = ( (self.num_species * (self.num_species + 1)) // 2 * self.ani_representation_module.angular_symmetry_functions.angular_sublength ) - - # The length of full aev - self.aev_length = self.radial_length + self.angular_length - - # Intialize interaction blocks - self.interaction_modules = ANIInteraction(self.aev_length) + aev_length = radial_length + angular_length + + # Initialize interaction modules for predicting properties from AEVs + self.interaction_modules = ANIInteraction( + aev_dim=aev_length, + activation_function=self.activation_function, + predicted_properties=predicted_properties, + predicted_dim=predicted_dim, + ) # ----- ATOMIC NUMBER LOOKUP -------- # Create a tensor for direct lookup. The size of this tensor will be # # the max atomic number in map. Initialize with a default value (e.g., -1 for not found). - max_atomic_number = max(ATOMIC_NUMBER_TO_INDEX_MAP.keys()) - lookup_tensor = torch.full((max_atomic_number + 1,), -1, dtype=torch.long) + maximum_atomic_number = max(ATOMIC_NUMBER_TO_INDEX_MAP.keys()) + lookup_tensor = torch.full((maximum_atomic_number + 1,), -1, dtype=torch.long) # Populate the lookup tensor with indices from your map for atomic_number, index in ATOMIC_NUMBER_TO_INDEX_MAP.items(): lookup_tensor[atomic_number] = index self.register_buffer("lookup_tensor", lookup_tensor) + # Apply the custom weight initialization + self.apply(init_params) - def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" - ) -> AniNeuralNetworkData: - number_of_atoms = data.atomic_numbers.shape[0] - - nnp_data = AniNeuralNetworkData( - pair_indices=pairlist_output.pair_indices, - d_ij=pairlist_output.d_ij, - r_ij=pairlist_output.r_ij, - number_of_atoms=number_of_atoms, - positions=data.positions, - atom_index=self.lookup_tensor[data.atomic_numbers.long()], - atomic_numbers=data.atomic_numbers, - atomic_subsystem_indices=data.atomic_subsystem_indices, - total_charge=data.total_charge, - ) - - return nnp_data - - def compute_properties(self, data: AniNeuralNetworkData) -> Dict[str, torch.Tensor]: + def compute_properties( + self, + data: NNPInput, + pairlist_output: PairlistData, + atom_index: torch.Tensor, + ) -> Dict[str, torch.Tensor]: """ - Calculate the energy for a given input batch. + Compute atomic properties (like energy) from AEVs. + + This is the main computation method, which processes the input data and + pairlist to generate per-atom predictions. Parameters ---------- - data : AniNeuralNetworkInput - - pairlist: shape (n_pairs, 2) - - r_ij: shape (n_pairs, 1) - - d_ij: shape (n_pairs, 3) - - positions: shape (nr_of_atoms_per_molecules, 3) + data : NNPInput + The input data for the ANI model, including atomic numbers and positions. + pairlist_output : PairlistData + The pairwise distances and displacement vectors between atoms. + atom_index : torch.Tensor + The indices of atomic species in the input data. Returns ------- - torch.Tensor - Calculated energies; shape (nr_systems,). + Dict[str, torch.Tensor] + The calculated per-atom properties and the scalar representation of AEVs. """ - # compute the representation (atomic environment vectors) for each atom - representation = self.ani_representation_module(data) - # compute the atomic energies - E_i = self.interaction_modules(representation) + # Compute AEVs (atomic environment vectors) + representation = self.ani_representation_module( + data, pairlist_output, atom_index + ) + # Use interaction modules to predict properties from AEVs + predictions = self.interaction_modules(representation) + # generate the output results return { - "per_atom_energy": E_i, + "per_atom_prediction": predictions, + "per_atom_scalar_representation": representation.aevs, "atomic_subsystem_indices": data.atomic_subsystem_indices, } + def _aggregate_results( + self, outputs: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """ + Aggregate per-atom predictions into property-specific tensors. -from typing import Union, Optional, List, Dict - - -class ANI2x(BaseNetwork): - def __init__( - self, - radial_max_distance: Union[unit.Quantity, str], - radial_min_distance: Union[unit.Quantity, str], - number_of_radial_basis_functions: int, - angular_max_distance: Union[unit.Quantity, str], - angular_min_distance: Union[unit.Quantity, str], - angular_dist_divisions: int, - angle_sections: int, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic: Optional[Dict[str, float]] = None, - ) -> None: - - from modelforge.utils.units import _convert - - self.only_unique_pairs = True # NOTE: need to be set before super().__init__ + This method splits the concatenated per-atom predictions into individual properties. - super().__init__( - dataset_statistic=dataset_statistic, - postprocessing_parameter=postprocessing_parameter, - cutoff=_convert(radial_max_distance), - ) + Parameters + ---------- + outputs : Dict[str, torch.Tensor] + A dictionary containing per-atom predictions. - self.core_module = ANI2xCore( - _convert(radial_max_distance), - _convert(radial_min_distance), - number_of_radial_basis_functions, - _convert(angular_max_distance), - _convert(angular_min_distance), - angular_dist_divisions, - angle_sections, + Returns + ------- + Dict[str, torch.Tensor] + A dictionary containing the split predictions for each property. + """ + # retrieve the per-atom predictions (nr_atoms, nr_properties) + per_atom_prediction = outputs.pop("per_atom_prediction") + # split the predictions into individual properties + split_tensors = torch.split(per_atom_prediction, self.predicted_dim, dim=1) + # update the outputs with the split predictions + outputs.update( + { + label: tensor + for label, tensor in zip(self.predicted_properties, split_tensors) + } ) + return outputs - def _config_prior(self): - log.info("Configuring ANI2x model hyperparameter prior distribution") - from modelforge.utils.io import import_ - - tune = import_("ray").tune - # from ray import tune + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: + """ + Forward pass through the ANI2x model to compute atomic properties. - from modelforge.train.utils import shared_config_prior + This method combines the AEV computation with the property prediction + step. - prior = { - "radial_max_distance": tune.uniform(5, 10), - "radial_min_distance": tune.uniform(0.6, 1.4), - "number_of_radial_basis_functions": tune.randint(12, 20), - "angular_max_distance": tune.uniform(2.5, 4.5), - "angular_min_distance": tune.uniform(0.6, 1.4), - "angle_sections": tune.randint(3, 8), - } - prior.update(shared_config_prior()) - return prior + Parameters + ---------- + data : NNPInput + The input data for the model, including atomic numbers and + positions. + pairlist_output : PairlistData + The pairwise distance and displacement vectors between atoms. - def combine_per_atom_properties( - self, values: Dict[str, torch.Tensor] - ) -> torch.Tensor: - return values + Returns + ------- + Dict[str, torch.Tensor] + A dictionary of calculated properties, including per-atom + predictions and AEVs. + """ + atom_index = self.lookup_tensor[data.atomic_numbers.long()] + # perform the forward pass implemented in the subclass + outputs = self.compute_properties(data, pairlist_output, atom_index) + # add atomic numbers to the output + outputs["atomic_numbers"] = data.atomic_numbers + # extract predictions per property + return self._aggregate_results(outputs) diff --git a/modelforge/potential/bayesian_models.py b/modelforge/potential/bayesian_models.py index 2a6f285b..15f45a57 100644 --- a/modelforge/potential/bayesian_models.py +++ b/modelforge/potential/bayesian_models.py @@ -1,61 +1,61 @@ -import torch -import pyro -from pyro.nn.module import to_pyro_module_ - -def init_log_sigma(model, value): - """Initializes the log_sigma parameters of a model - - Parameters - ---------- - model : torch.nn.Module - The model to initialize - - value : float - The value to initialize the log_sigma parameters to - - """ - log_sigma_params = { - name + "_log_sigma": pyro.nn.PyroParam( - torch.ones(param.shape) * value, - ) - for name, param in model.named_parameters() - } - - for name, param in log_sigma_params.items(): - setattr(model, name, param) - -class BayesianAutoNormalPotential(torch.nn.Module): - """A Bayesian model with a normal prior and likelihood. - - Parameters - ---------- - log_sigma : float, optional - The initial value of the log_sigma parameters. Default is 0.0. - - Methods - ------- - model - The model function. If no `y` argument is provided, - provide the prior; if `y` is provided, provide the likelihood. - """ - def __init__( - self, - *args, **kwargs, - ): - super().__init__() - log_sigma = kwargs.pop("log_sigma", 0.0) - init_log_sigma(self, log_sigma) - - def model(self, *args, **kwargs): - """The model function. If no `y` argument is provided, - provide the prior; if `y` is provided, provide the likelihood. - """ - y = kwargs.pop("y", None) - y_hat = self(*args, **kwargs) - pyro.sample( - "obs", - pyro.distributions.Delta(y_hat), - obs=y - ) - - +# import torch +# import pyro +# from pyro.nn.module import to_pyro_module_ +# +# def init_log_sigma(model, value): +# """Initializes the log_sigma parameters of a model +# +# Parameters +# ---------- +# model : torch.nn.Module +# The model to initialize +# +# value : float +# The value to initialize the log_sigma parameters to +# +# """ +# log_sigma_params = { +# name + "_log_sigma": pyro.nn.PyroParam( +# torch.ones(param.shape) * value, +# ) +# for name, param in model.named_parameters() +# } +# +# for name, param in log_sigma_params.items(): +# setattr(model, name, param) +# +# class BayesianAutoNormalPotential(torch.nn.Module): +# """A Bayesian model with a normal prior and likelihood. +# +# Parameters +# ---------- +# log_sigma : float, optional +# The initial value of the log_sigma parameters. Default is 0.0. +# +# Methods +# ------- +# model +# The model function. If no `y` argument is provided, +# provide the prior; if `y` is provided, provide the likelihood. +# """ +# def __init__( +# self, +# *args, **kwargs, +# ): +# super().__init__() +# log_sigma = kwargs.pop("log_sigma", 0.0) +# init_log_sigma(self, log_sigma) +# +# def model(self, *args, **kwargs): +# """The model function. If no `y` argument is provided, +# provide the prior; if `y` is provided, provide the likelihood. +# """ +# y = kwargs.pop("y", None) +# y_hat = self(*args, **kwargs) +# pyro.sample( +# "obs", +# pyro.distributions.Delta(y_hat), +# obs=y +# ) +# +# diff --git a/modelforge/potential/featurization.py b/modelforge/potential/featurization.py new file mode 100644 index 00000000..c1e8a6cb --- /dev/null +++ b/modelforge/potential/featurization.py @@ -0,0 +1,230 @@ +import torch +import torch.nn as nn +from typing import Dict, List, Union + +from modelforge.potential.utils import DenseWithCustomDist +from modelforge.utils.prop import NNPInput + + +class AddPerMoleculeValue(nn.Module): + """ + Module that adds a per-molecule value to a per-atom property tensor. + The per-molecule value is expanded to match th elength of the per-atom property tensor. + + Parameters + ---------- + key : str + The key to access the per-molecule value from the input data. + + Attributes + ---------- + key : str + The key to access the per-molecule value from the input data. + """ + + def __init__(self, key: str): + super().__init__() + self.key = key + + def forward( + self, per_atom_property_tensor: torch.Tensor, data: NNPInput + ) -> torch.Tensor: + """ + Forward pass of the module. + + Parameters + ---------- + per_atom_property_tensor : torch.Tensor + The per-atom property tensor. + data : NNPInput + The input data containing the per-molecule value. + + Returns + ------- + torch.Tensor + The updated per-atom property tensor with the per-molecule value appended. + """ + values_to_append = getattr(data, self.key) + _, counts = torch.unique(data.atomic_subsystem_indices, return_counts=True) + expanded_values = torch.repeat_interleave(values_to_append, counts).unsqueeze(1) + return torch.cat((per_atom_property_tensor, expanded_values), dim=1) + + +class AddPerAtomValue(nn.Module): + """ + Module that adds a per-atom value to a tensor. + + Parameters + ---------- + key : str + The key to access the per-atom value from the input data. + + Attributes + ---------- + key : str + The key to access the per-atom value from the input data. + """ + + def __init__(self, key: str): + super().__init__() + self.key = key + + def forward( + self, per_atom_property_tensor: torch.Tensor, data: NNPInput + ) -> torch.Tensor: + """ + Forward pass of the module. + + Parameters + ---------- + per_atom_property_tensor : torch.Tensor + The input tensor representing per-atom properties. + data : NNPInput + The input data object containing additional information. + + Returns + ------- + torch.Tensor + The tensor with the per-atom value appended. + """ + values_to_append = getattr(data, self.key) + return torch.cat((per_atom_property_tensor, values_to_append), dim=1) + + +class FeaturizeInput(nn.Module): + + _SUPPORTED_FEATURIZATION_TYPES = [ + "atomic_number", + "per_system_total_charge", + "spin_state", + ] + + def __init__(self, featurization_config: Dict[str, Dict[str, int]]) -> None: + """ + Initialize the FeaturizeInput class. + + For per-atom non-categorical properties and per-molecule properties + (both categorical and non-categorical), we append the embedded nuclear + charges and mix them using a linear layer. + + For per-atom categorical properties, we define an additional embedding + and add the embedding to the nuclear charge embedding. + + Parameters + ---------- + featurization_config : dict + A dictionary containing the featurization configuration. It should + have the following keys: + - "properties_to_featurize" : list + A list of properties to featurize. + - "maximum_atomic_number" : int + The maximum atomic number. + - "number_of_per_atom_features" : int + The number of per-atom features. + + Returns + ------- + None + """ + super().__init__() + + # expend embedding vector + self.append_to_embedding_tensor = nn.ModuleList() + self.registered_appended_properties: List[str] = [] + # what other categorial properties are embedded + self.embeddings = nn.ModuleList() + self.registered_embedding_operations: List[str] = [] + + self.increase_dim_of_embedded_tensor: int = 0 + base_embedding_dim = int( + featurization_config["atomic_number"]["number_of_per_atom_features"] + ) + properties_to_featurize = featurization_config["properties_to_featurize"] + # iterate through the supported featurization types and check if one of + # these is requested + for featurization in properties_to_featurize: + + # embed atomic number + if ( + featurization == "atomic_number" + and featurization in self._SUPPORTED_FEATURIZATION_TYPES + ): + self.atomic_number_embedding = torch.nn.Embedding( + int(featurization_config[featurization]["maximum_atomic_number"]), + int( + featurization_config[featurization][ + "number_of_per_atom_features" + ] + ), + ) + self.registered_embedding_operations.append("atomic_number") + + # add total charge to embedding vector + elif ( + featurization == "per_system_total_charge" + and featurization in self._SUPPORTED_FEATURIZATION_TYPES + ): + # transform output o f embedding with shape (nr_atoms, + # nr_features) to (nr_atoms, nr_features + 1). The added + # features is the total charge (which will be transformed to a + # per-atom property) + self.append_to_embedding_tensor.append( + AddPerMoleculeValue("per_system_total_charge") + ) + self.increase_dim_of_embedded_tensor += 1 + self.registered_appended_properties.append("per_system_total_charge") + + # add partial charge to embedding vector + elif ( + featurization == "per_atom_partial_charge" + and featurization in self._SUPPORTED_FEATURIZATION_TYPES + ): # transform output of embedding with shape (nr_atoms, nr_features) to (nr_atoms, nr_features + 1). + # #The added features is the total charge (which will be + # transformed to a per-atom property) + self.append_to_embedding_tensor.append( + AddPerAtomValue("partial_charge") + ) + self.increase_dim_of_embedded_tensor += 1 + self.append_to_embedding_tensor("partial_charge") + + else: + raise RuntimeError( + f"Unsupported featurization type {featurization}. Supported types are {self._SUPPORTED_FEATURIZATION_TYPES}" + ) + + # if only nuclear charges are embedded no mixing is performed + self.mixing: Union[nn.Identity, DenseWithCustomDist] + if self.increase_dim_of_embedded_tensor == 0: + self.mixing = nn.Identity() + else: + self.mixing = DenseWithCustomDist( + base_embedding_dim + self.increase_dim_of_embedded_tensor, + base_embedding_dim, + ) + + def forward(self, data: NNPInput) -> torch.Tensor: + """ + Featurize the input data. + + Parameters + ---------- + data : NNPInput + The input data. + + Returns + ------- + torch.Tensor + The featurized input data. + """ + atomic_numbers = data.atomic_numbers + categorial_embedding = self.atomic_number_embedding(atomic_numbers) + if torch.isnan(categorial_embedding).any(): + raise ValueError("NaN values detected in categorial_embedding.") + + for additional_embedding in self.embeddings: + categorial_embedding = additional_embedding(categorial_embedding, data) + + for append_embedding_vector in self.append_to_embedding_tensor: + categorial_embedding = append_embedding_vector(categorial_embedding, data) + + return self.mixing(categorial_embedding) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py deleted file mode 100644 index ac31b16d..00000000 --- a/modelforge/potential/models.py +++ /dev/null @@ -1,1112 +0,0 @@ -from abc import ABC, abstractmethod -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Mapping, - NamedTuple, - Tuple, - Type, - Optional, - List, -) - -import lightning as pl -import torch -from loguru import logger as log -from openff.units import unit -from torch.nn import Module - -from modelforge.dataset.dataset import NNPInput - -if TYPE_CHECKING: - from modelforge.potential.ani import ANI2x, AniNeuralNetworkData - from modelforge.potential.painn import PaiNN, PaiNNNeuralNetworkData - from modelforge.potential.physnet import PhysNet, PhysNetNeuralNetworkData - from modelforge.potential.sake import SAKE, SAKENeuralNetworkInput - from modelforge.potential.schnet import SchNet, SchnetNeuralNetworkData - - -# Define NamedTuple for the outputs of Pairlist and Neighborlist forward method -class PairListOutputs(NamedTuple): - """ - A namedtuple to store the outputs of the Pairlist and Neighborlist forward methods. - - Attributes: - pair_indices (torch.Tensor): A tensor of shape (2, n_pairs) containing the indices of the interacting atom pairs. - d_ij (torch.Tensor): A tensor of shape (n_pairs, 1) containing the Euclidean distances between the atoms in each pair. - r_ij (torch.Tensor): A tensor of shape (n_pairs, 3) containing the displacement vectors between the atoms in each pair. - """ - - pair_indices: torch.Tensor - d_ij: torch.Tensor - r_ij: torch.Tensor - - -class Pairlist(Module): - """ - Handle pair list calculations for atoms, returning atom indices pairs and displacement vectors. - - Attributes: - only_unique_pairs (bool): If True, only unique pairs are returned (default is False). - Otherwise, all pairs are returned. - """ - - def __init__(self, only_unique_pairs: bool = False): - """ - Initialize the Pairlist object. - - Parameters: - only_unique_pairs (bool, optional): If True, only unique pairs are returned (default is False). - Otherwise, all pairs are returned. - """ - super().__init__() - self.only_unique_pairs = only_unique_pairs - - def enumerate_all_pairs(self, atomic_subsystem_indices: torch.Tensor): - """Compute all pairs of atoms and their distances. - - Parameters - ---------- - atomic_subsystem_indices : torch.Tensor, shape (nr_atoms_per_systems) - Atom indices to indicate which atoms belong to which molecule - Note in all cases, the values in this tensor must be numbered from 0 to n_molecules - 1 - sequentially, with no gaps in the numbering. E.g., [0,0,0,1,1,2,2,2 ...]. - This is the case for all internal data structures, and those no validation is performed in - this routine. If the data is not structured in this way, the results will be incorrect. - - """ - - # get device that passed tensors lives on, initialize on the same device - device = atomic_subsystem_indices.device - - # if there is only one molecule, we do not need to use additional looping and offsets - if torch.sum(atomic_subsystem_indices) == 0: - n = len(atomic_subsystem_indices) - if self.only_unique_pairs: - i_final_pairs, j_final_pairs = torch.triu_indices( - n, n, 1, device=device - ) - else: - # Repeat each number n-1 times for i_indices - i_final_pairs = torch.repeat_interleave( - torch.arange(n, device=device), - repeats=n - 1, - ) - - # Correctly construct j_indices - j_final_pairs = torch.cat( - [ - torch.cat( - ( - torch.arange(i, device=device), - torch.arange(i + 1, n, device=device), - ) - ) - for i in range(n) - ] - ) - - else: - # if we have more than one molecule, we will take into account molecule size and offsets when - # calculating pairs, as using the approach above is not memory efficient for datasets with large molecules - # and/or larger batch sizes; while not likely a problem on higher end GPUs with large amounts of memory - # cheaper commodity and mobile GPUs may have issues - - # atomic_subsystem_indices are always numbered from 0 to n_molecules - 1 - # e.g., a single molecule will be [0, 0, 0, 0 ... ] - # and a batch of molecules will always start at 0 and increment [ 0, 0, 0, 1, 1, 1, ...] - # As such, we can use bincount, as there are no gaps in the numbering - # Note if the indices are not numbered from 0 to n_molecules - 1, this will not work - # E.g., bincount on [3,3,3, 4,4,4, 5,5,5] will return [0,0,0,3,3,3,3,3,3] - # as we have no values for 0, 1, 2 - # using a combination of unique and argsort would make this work for any numbering ordering - # but that is not how the data ends up being structured internally, and thus is not needed - repeats = torch.bincount(atomic_subsystem_indices) - offsets = torch.cat( - (torch.tensor([0], device=device), torch.cumsum(repeats, dim=0)[:-1]) - ) - - i_indices = torch.cat( - [ - torch.repeat_interleave( - torch.arange(o, o + r, device=device), repeats=r - ) - for r, o in zip(repeats, offsets) - ] - ) - j_indices = torch.cat( - [ - torch.cat([torch.arange(o, o + r, device=device) for _ in range(r)]) - for r, o in zip(repeats, offsets) - ] - ) - - if self.only_unique_pairs: - # filter out pairs that are not unique - unique_pairs_mask = i_indices < j_indices - i_final_pairs = i_indices[unique_pairs_mask] - j_final_pairs = j_indices[unique_pairs_mask] - else: - # filter out identical values - unique_pairs_mask = i_indices != j_indices - i_final_pairs = i_indices[unique_pairs_mask] - j_final_pairs = j_indices[unique_pairs_mask] - - # concatenate to form final (2, n_pairs) tensor - pair_indices = torch.stack((i_final_pairs, j_final_pairs)) - - return pair_indices.to(device) - - def construct_initial_pairlist_using_numpy( - self, atomic_subsystem_indices: torch.Tensor - ): - """Compute all pairs of atoms and also return counts of the number of pairs for each molecule in batch. - - Parameters - ---------- - atomic_subsystem_indices : torch.Tensor, shape (nr_atoms_per_systems) - Atom indices to indicate which atoms belong to which molecule - Note in all cases, the values in this tensor must be numbered from 0 to n_molecules - 1 - sequentially, with no gaps in the numbering. E.g., [0,0,0,1,1,2,2,2 ...]. - This is the case for all internal data structures, and those no validation is performed in - this routine. If the data is not structured in this way, the results will be incorrect. - Returns - ------- - pair_indices : np.ndarray, shape (2, n_pairs) - Pairs of atom indices, 0-indexed for each molecule - number_of_pairs : np.ndarray, shape (n_molecules) - The number to index into pair_indices for each molecule - - """ - - # atomic_subsystem_indices are always numbered from 0 to n_molecules - 1 - # e.g., a single molecule will be [0, 0, 0, 0 ... ] - # and a batch of molecules will always start at 0 and increment [ 0, 0, 0, 1, 1, 1, ...] - # As such, we can use bincount, as there are no gaps in the numbering - # Note if the indices are not numbered from 0 to n_molecules - 1, this will not work - # E.g., bincount on [3,3,3, 4,4,4, 5,5,5] will return [0,0,0,3,3,3,3,3,3] - # as we have no values for 0, 1, 2 - # using a combination of unique and argsort would make this work for any numbering ordering - # but that is not how the data ends up being structured internally, and thus is not needed - - import numpy as np - - # get the number of atoms in each molecule - repeats = np.bincount(atomic_subsystem_indices) - - # calculate the number of pairs for each molecule, using simple permutation - npairs_by_molecule = np.array([r * (r - 1) for r in repeats], dtype=np.int16) - - i_indices = np.concatenate( - [ - np.repeat( - np.arange( - 0, - r, - dtype=np.int16, - ), - repeats=r, - ) - for r in repeats - ] - ) - j_indices = np.concatenate( - [ - np.concatenate([np.arange(0, 0 + r, dtype=np.int16) for _ in range(r)]) - for r in repeats - ] - ) - - # filter out identical pairs where i==j - unique_pairs_mask = i_indices != j_indices - i_final_pairs = i_indices[unique_pairs_mask] - j_final_pairs = j_indices[unique_pairs_mask] - - # concatenate to form final (2, n_pairs) vector - pair_indices = np.stack((i_final_pairs, j_final_pairs)) - - return pair_indices, npairs_by_molecule - - def calculate_r_ij( - self, pair_indices: torch.Tensor, positions: torch.Tensor - ) -> torch.Tensor: - """Compute displacement vectors between atom pairs. - - Parameters - ---------- - pair_indices : torch.Tensor - Atom indices for pairs of atoms. Shape: [2, n_pairs]. - positions : torch.Tensor - Atom positions. Shape: [atoms, 3]. - - Returns - ------- - torch.Tensor - Displacement vectors between atom pairs. Shape: [n_pairs, 3]. - """ - # Select the pairs of atom coordinates from the positions - selected_positions = positions.index_select(0, pair_indices.view(-1)).view( - 2, -1, 3 - ) - return selected_positions[1] - selected_positions[0] - - def calculate_d_ij(self, r_ij: torch.Tensor) -> torch.Tensor: - """Compute Euclidean distances between atoms in each pair. - - Parameters - ---------- - r_ij : torch.Tensor - Displacement vectors between atoms in a pair. Shape: [n_pairs, 3]. - - Returns - ------- - torch.Tensor - Euclidean distances. Shape: [n_pairs, 1]. - """ - return r_ij.norm(dim=1).unsqueeze(1) - - def forward( - self, - positions: torch.Tensor, - atomic_subsystem_indices: torch.Tensor, - ) -> PairListOutputs: - """ - Performs the forward pass of the Pairlist module. - - Parameters - ---------- - positions : torch.Tensor - Atom positions. Shape: [nr_atoms, 3]. - atomic_subsystem_indices (torch.Tensor, shape (nr_atoms_per_systems)): - Atom indices to indicate which atoms belong to which molecule. - - Returns - ------- - PairListOutputs: A namedtuple containing the following attributes: - pair_indices (torch.Tensor): A tensor of shape (2, n_pairs) containing the indices of the interacting atom pairs. - d_ij (torch.Tensor): A tensor of shape (n_pairs, 1) containing the Euclidean distances between the atoms in each pair. - r_ij (torch.Tensor): A tensor of shape (n_pairs, 3) containing the displacement vectors between the atoms in each pair. - """ - pair_indices = self.enumerate_all_pairs( - atomic_subsystem_indices, - ) - r_ij = self.calculate_r_ij(pair_indices, positions) - - return PairListOutputs( - pair_indices=pair_indices, - d_ij=self.calculate_d_ij(r_ij), - r_ij=r_ij, - ) - - -class Neighborlist(Pairlist): - """Manage neighbor list calculations with a specified cutoff distance. - - This class extends Pairlist to consider a cutoff distance for neighbor calculations. - """ - - def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = False): - """ - Initialize the Neighborlist with a specific cutoff distance. - - Parameters - ---------- - cutoff : unit.Quantity - Cutoff distance for neighbor calculations. - """ - super().__init__(only_unique_pairs=only_unique_pairs) - self.register_buffer("cutoff", torch.tensor(cutoff.to(unit.nanometer).m)) - - def forward( - self, - positions: torch.Tensor, - atomic_subsystem_indices: torch.Tensor, - pair_indices: Optional[torch.Tensor] = None, - ) -> PairListOutputs: - """ - Forward pass to compute neighbor list considering a cutoff distance. - - Overrides the `forward` method from Pairlist to include cutoff distance in calculations. - - Parameters - ---------- - positions : torch.Tensor - Atom positions. Shape: [nr_systems, nr_atoms, 3]. - atomic_subsystem_indices : torch.Tensor - Indices identifying atoms in subsystems. Shape: [nr_atoms]. - - Returns - ------- - PairListOutputs - A NamedTuple containing 'pair_indices', 'd_ij' (distances), and 'r_ij' (displacement vectors). - """ - - if pair_indices is None: - pair_indices = self.enumerate_all_pairs( - atomic_subsystem_indices, - ) - - r_ij = self.calculate_r_ij(pair_indices, positions) - d_ij = self.calculate_d_ij(r_ij) - - # Find pairs within the cutoff - in_cutoff = (d_ij <= self.cutoff).squeeze() - # Get the atom indices within the cutoff - pair_indices_within_cutoff = pair_indices[:, in_cutoff] - - return PairListOutputs( - pair_indices=pair_indices_within_cutoff, - d_ij=d_ij[in_cutoff], - r_ij=r_ij[in_cutoff], - ) - - -from typing import Callable, Literal, Optional, Union - -import numpy as np - - -class JAXModel: - """A model wrapper that facilitates calling a JAX function with predefined parameters and buffers. - - Attributes - ---------- - jax_fn : Callable - The JAX function to be called. - parameter : jax. - Parameters required by the JAX function. - buffer : Any - Buffers required by the JAX function. - name : str - Name of the model. - """ - - def __init__( - self, jax_fn: Callable, parameter: np.ndarray, buffer: np.ndarray, name: str - ): - self.jax_fn = jax_fn - self.parameter = parameter - self.buffer = buffer - self.name = name - - def __call__(self, data: NamedTuple): - """Calls the JAX function using the stored parameters and buffers along with additional data. - - Parameters - ---------- - data : NamedTuple - Data to be passed to the JAX function. - - Returns - ------- - Any - The result of the JAX function. - """ - - return self.jax_fn(self.parameter, self.buffer, data) - - def __repr__(self): - return f"{self.__class__.__name__} wrapping {self.name}" - - -class PyTorch2JAXConverter: - """ - Wraps a PyTorch neural network potential instance in a Flax module using the - `pytorch2jax` library (https://github.com/subho406/Pytorch2Jax). - The converted model uses dlpack to convert between Pytorch and Jax tensors - in-memory and executes Pytorch backend inside Jax wrapped functions. - The wrapped modules are compatible with Jax backward-mode autodiff. - - Parameters - ---------- - nnp_instance : Any - The neural network potential instance to convert. - - Returns - ------- - JAXModel - The converted JAX model. - """ - - def convert_to_jax_model( - self, nnp_instance: Union["ANI2x", "SchNet", "PaiNN", "PhysNet"] - ) -> JAXModel: - """ - Convert a PyTorch neural network instance to a JAX model. - - Parameters - ---------- - nnp_instance : Union["ANI2x", "SchNet", "PaiNN", "PhysNet"] - The PyTorch neural network instance to be converted. - - Returns - ------- - JAXModel - A JAX model containing the converted neural network function, parameters, and buffers. - """ - - jax_fn, params, buffers = self._convert_pytnn_to_jax(nnp_instance) - return JAXModel(jax_fn, params, buffers, nnp_instance.__class__.__name__) - - @staticmethod - def _convert_pytnn_to_jax( - nnp_instance: Union["ANI2x", "SchNet", "PaiNN", "PhysNet"] - ) -> Tuple[Callable, np.ndarray, np.ndarray]: - """Internal method to convert PyTorch neural network parameters and buffers to JAX format. - - Parameters - ---------- - nnp_instance : Any - The PyTorch neural network instance. - - Returns - ------- - Tuple[Callable, Any, Any] - A tuple containing the JAX function, parameters, and buffers. - """ - - # make sure - from modelforge.utils.io import import_ - - jax = import_("jax") - # use the wrapper to check if pytorch2jax is in the environment - - custom_vjp = import_("jax").custom_vjp - - # from jax import custom_vjp - convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax - convert_to_pyt = import_("pytorch2jax").pytorch2jax.convert_to_pyt - # from pytorch2jax.pytorch2jax import convert_to_jax, convert_to_pyt - - import functorch - from functorch import make_functional_with_buffers - - # Convert the PyTorch model to a functional representation and extract the model function and parameters - model_fn, model_params, model_buffer = make_functional_with_buffers( - nnp_instance - ) - - # Convert the model parameters from PyTorch to JAX representations - model_params = jax.tree_map(convert_to_jax, model_params) - # Convert the model buffer from PyTorch to JAX representations - model_buffer = jax.tree_map(convert_to_jax, model_buffer) - - # Define the apply function using a custom VJP - @custom_vjp - def apply(params, *args, **kwargs): - # Convert the input data from JAX to PyTorch - params, args, kwargs = map( - lambda x: jax.tree_map(convert_to_pyt, x), (params, args, kwargs) - ) - # Apply the model function to the input data - out = model_fn(params, *args, **kwargs) - # Convert the output data from PyTorch to JAX - out = jax.tree_map(convert_to_jax, out) - return out - - # Define the forward and backward passes for the VJP - def apply_fwd(params, *args, **kwargs): - return apply(params, *args, **kwargs), (params, args, kwargs) - - def apply_bwd(res, grads): - params, args, kwargs = res - params, args, kwargs = map( - lambda x: jax.tree_map(convert_to_pyt, x), (params, args, kwargs) - ) - grads = jax.tree_map(convert_to_pyt, grads) - # Compute the gradients using the model function and convert them from JAX to PyTorch representations - grads = functorch.vjp(model_fn, params, *args, **kwargs)[1](grads) - return jax.tree_map(convert_to_jax, grads) - - apply.defvjp(apply_fwd, apply_bwd) - - # Return the apply function and the converted model parameters - return apply, model_params, model_buffer - - -class NeuralNetworkPotentialFactory: - """ - Factory class for creating instances of neural network potentials for training/inference. - """ - - @staticmethod - def generate_model( - *, - use: Literal["training", "inference"], - model_parameter: Dict[str, Union[str, Any]], - simulation_environment: Literal["PyTorch", "JAX"] = "PyTorch", - training_parameter: Optional[Dict[str, Any]] = None, - dataset_statistic: Optional[Dict[str, float]] = None, - ) -> Union[Type[torch.nn.Module], Type[JAXModel], Type[pl.LightningModule]]: - """ - Creates an NNP instance of the specified type, configured either for training or inference. - - Parameters - ---------- - use : str - The use case for the model instance, either 'training' or 'inference'. - simulation_environment : str - The ML framework to use, either 'PyTorch' or 'JAX'. - model_parameter : dict, optional - Parameters specific to the model, by default {}. - training_parameter : dict, optional - Parameters for configuring the training, by default {}. - - Returns - ------- - Union[Union[torch.nn.Module], pl.LightningModule, JAXModel] - An instantiated model. - - Raises - ------ - ValueError - If an unknown use case is requested. - NotImplementedError - If the requested model type is not implemented. - """ - - from modelforge.potential import _Implemented_NNPs - from modelforge.train.training import TrainingAdapter - - # obtain model for training - if use == "training": - if simulation_environment == "JAX": - log.warning( - "Training in JAX is not availalbe. Falling back to PyTorch." - ) - model = TrainingAdapter( - model_parameter=model_parameter, - **training_parameter, - dataset_statistic=dataset_statistic, - ) - return model - # obtain model for inference - elif use == "inference": - model_type = model_parameter["model_name"] - nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_type) - model = nnp_class( - **model_parameter["core_parameter"], - postprocessing_parameter=model_parameter["postprocessing_parameter"], - dataset_statistic=dataset_statistic, - ) - if simulation_environment == "JAX": - return PyTorch2JAXConverter().convert_to_jax_model(model) - else: - return model - else: - raise NotImplementedError(f"Unsupported 'use' value: {use}") - - -class InputPreparation(torch.nn.Module): - def __init__(self, cutoff: unit.Quantity, only_unique_pairs: bool = True): - """ - A module for preparing input data, including the calculation of pair lists, distances (d_ij), and displacement vectors (r_ij) for molecular simulations. - Parameters - ---------- - cutoff : unit.Quantity - The cutoff distance for neighbor list calculations. - only_unique_pairs : bool, optional - Whether to only use unique pairs in the pair list calculation, by default True. This should be set to True for all message passing networks. - - """ - - super().__init__() - from .models import Neighborlist - - self.only_unique_pairs = only_unique_pairs - self.calculate_distances_and_pairlist = Neighborlist(cutoff, only_unique_pairs) - - def prepare_inputs(self, data: Union[NNPInput, NamedTuple]): - """ - Prepares the input tensors for passing to the model. - - This method handles general input manipulation, such as calculating distances - and generating the pair list. It also calls the model-specific input preparation. - - Parameters - ---------- - data : Union[NNPInput, NamedTuple] - The input data provided by the dataset, containing atomic numbers, positions, and other necessary information. - - Returns - ------- - PairListOutputs - A namedtuple containing the pair indices, Euclidean distances (d_ij), and displacement vectors (r_ij). - """ - # --------------------------- - # general input manipulation - positions = data.positions - atomic_subsystem_indices = data.atomic_subsystem_indices - # calculate pairlist if none is provided - if data.pair_list is None: - pairlist_output = self.calculate_distances_and_pairlist( - positions=positions, - atomic_subsystem_indices=atomic_subsystem_indices, - pair_indices=None, - ) - pair_list = data.pair_list - else: - # pairlist is provided, remove redundant pairs if requested - if self.only_unique_pairs: - i_indices = data.pair_list[0] - j_indices = data.pair_list[1] - unique_pairs_mask = i_indices < j_indices - i_final_pairs = i_indices[unique_pairs_mask] - j_final_pairs = j_indices[unique_pairs_mask] - pair_list = torch.stack((i_final_pairs, j_final_pairs)) - else: - pair_list = data.pair_list - # only calculate d_ij and r_ij - pairlist_output = self.calculate_distances_and_pairlist( - positions=positions, - atomic_subsystem_indices=atomic_subsystem_indices, - pair_indices=pair_list.to(torch.int64), - ) - - return pairlist_output - - def _input_checks(self, data: Union[NNPInput, NamedTuple]): - """ - Performs input validation checks. - - Ensures the input data conforms to expected shapes and types. - - Parameters - ---------- - data : NNPInput - The input data to be validated. - - Raises - ------ - ValueError - If the input data does not meet the expected criteria. - """ - # check that the input is instance of NNPInput - assert isinstance(data, NNPInput) or isinstance(data, Tuple) - - nr_of_atoms = data.atomic_numbers.shape[0] - assert data.atomic_numbers.shape == torch.Size([nr_of_atoms]) - assert data.atomic_subsystem_indices.shape == torch.Size([nr_of_atoms]) - nr_of_molecules = torch.unique(data.atomic_subsystem_indices).numel() - assert data.total_charge.shape == torch.Size([nr_of_molecules]) - assert data.positions.shape == torch.Size([nr_of_atoms, 3]) - - -from torch.nn import ModuleDict - - -class PostProcessing(torch.nn.Module): - """ - A module for handling post-processing operations on model outputs, including normalization, calculation of atomic self-energies, and reduction operations to compute per-molecule properties from per-atom properties. - """ - - _SUPPORTED_PROPERTIES = ["per_atom_energy", "general_postprocessing_operation"] - _SUPPORTED_OPERATIONS = ["normalize", "from_atom_to_molecule_reduction"] - - def __init__( - self, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic: Dict[str, Dict[str, float]], - ): - """ - Parameters - ---------- - postprocessing_parameter: Dict[str, Dict[str, bool]] # TODO: update - dataset_statistic : Dict[str, float] - A dictionary containing the dataset statistics for normalization and other calculations. - """ - super().__init__() - - self._registered_properties: List[str] = [] - - # operations that use nn.Sequence to pass the output of the model to the next - self.registered_chained_operations = ModuleDict() - - self.dataset_statistic = dataset_statistic - - self._initialize_postprocessing( - postprocessing_parameter, - ) - - def _get_per_atom_energy_mean_and_stddev_of_dataset(self) -> Tuple[float, float]: - """ - Calculate the mean and standard deviation of the per-atom energy in the dataset. - - Returns - ------- - Tuple[float, float] - The mean and standard deviation of the per-atom energy. - """ - if self.dataset_statistic is None: - mean = 0.0 - stddev = 1.0 - log.warning( - f"No mean and stddev provided for dataset. Setting to default value {mean=} and {stddev=}!" - ) - else: - training_dataset_statistics = self.dataset_statistic[ - "training_dataset_statistics" - ] - mean = unit.Quantity( - training_dataset_statistics["per_atom_energy_mean"] - ).m_as(unit.kilojoule_per_mole) - stddev = unit.Quantity( - training_dataset_statistics["per_atom_energy_stddev"] - ).m_as(unit.kilojoule_per_mole) - return mean, stddev - - def _initialize_postprocessing( - self, - postprocessing_parameter: Dict[str, Dict[str, bool]], - ): - """ - Initialize the postprocessing operations based on the given postprocessing parameters. - - Parameters: - postprocessing_parameter (Dict[str, Dict[str, bool]]): A dictionary containing the postprocessing parameters for each property. - - Raises: - ValueError: If a property is not supported. - - Returns: - None - """ - - from .processing import ( - FromAtomToMoleculeReduction, - ScaleValues, - CalculateAtomicSelfEnergy, - ) - - for property, operations in postprocessing_parameter.items(): - # register properties for which postprocessing should be performed - if property.lower() in self._SUPPORTED_PROPERTIES: - self._registered_properties.append(property.lower()) - else: - raise ValueError( - f"Property {property} is not supported. Supported properties are {self._SUPPORTED_PROPERTIES}" - ) - - # register operations that are performed for the property - postprocessing_sequence = torch.nn.Sequential() - prostprocessing_sequence_names = [] - - # for each property parse the requested operations - if property == "per_atom_energy": - if operations.get("normalize", False): - mean, stddev = ( - self._get_per_atom_energy_mean_and_stddev_of_dataset() - ) - postprocessing_sequence.append( - ScaleValues( - mean=mean, - stddev=stddev, - property="per_atom_energy", - output_name="per_atom_energy", - ) - ) - prostprocessing_sequence_names.append("normalize") - # check if also reduction is requested - if operations.get("from_atom_to_molecule_reduction", False): - postprocessing_sequence.append( - FromAtomToMoleculeReduction( - per_atom_property_name="per_atom_energy", - index_name="atomic_subsystem_indices", - output_name="per_molecule_energy", - keep_per_atom_property=operations.get( - "keep_per_atom_property", False - ), - ) - ) - prostprocessing_sequence_names.append( - "from_atom_to_molecule_reduction" - ) - elif property == "general_postprocessing_operation": - # check if also self-energies are requested - if operations.get("calculate_molecular_self_energy", False): - - if self.dataset_statistic is None: - log.warning( - "Dataset statistics are required to calculate the molecular self-energies but haven't been provided." - ) - else: - atomic_self_energies = self.dataset_statistic[ - "atomic_self_energies" - ] - - postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies) - ) - prostprocessing_sequence_names.append( - "calculate_molecular_self_energy" - ) - - postprocessing_sequence.append( - FromAtomToMoleculeReduction( - per_atom_property_name="ase_tensor", - index_name="atomic_subsystem_indices", - output_name="per_molecule_self_energy", - ) - ) - - # check if also self-energies are requested - elif operations.get("calculate_atomic_self_energy", False): - if self.dataset_statistic is None: - log.warning( - "Dataset statistics are required to calculate the molecular self-energies but haven't been provided." - ) - else: - atomic_self_energies = self.dataset_statistic[ - "atomic_self_energies" - ] - - postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies)() - ) - prostprocessing_sequence_names.append( - "calculate_atomic_self_energy" - ) - - log.debug(prostprocessing_sequence_names) - - self.registered_chained_operations[property] = postprocessing_sequence - - def forward(self, data: Dict[str, torch.Tensor]): - """ - Perform post-processing operations for all registered properties. - """ - - # NOTE: this is not very elegant, but I am unsure how to do this better - # I am currently directly writing new keys and values in the data dictionary - for property in PostProcessing._SUPPORTED_PROPERTIES: - if property in self._registered_properties: - self.registered_chained_operations[property](data) - - return data - - -class BaseNetwork(Module): - def __init__( - self, - *, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic, - cutoff: unit.Quantity, - ): - """ - The BaseNetwork wraps the input preparation (including pairlist calculation, d_ij and r_ij calculation), the actual model as well as the output preparation in a wrapper class. - - Learned parameters are present only in the core model, the input preparation and output preparation are not learned. - - Parameters - ---------- - postprocessing_parameter : Dict[str, Dict[str, bool]] # TODO: update - """ - - super().__init__() - from modelforge.utils.units import _convert - - self.postprocessing = PostProcessing( - postprocessing_parameter, dataset_statistic - ) - - # check if self.only_unique_pairs is set in child class - if not hasattr(self, "only_unique_pairs"): - raise RuntimeError( - "The only_unique_pairs attribute is not set in the child class. Please set it to True or False before calling super().__init__." - ) - self.input_preparation = InputPreparation( - cutoff=_convert(cutoff), only_unique_pairs=self.only_unique_pairs - ) - - def load_state_dict( - self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False - ): - """ - Load the state dictionary into the model, with optional prefix removal and key exclusions. - - Parameters - ---------- - state_dict : Mapping[str, Any] - The state dictionary to load. - strict : bool, optional - Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module's `state_dict()` function (default is True). - assign : bool, optional - Whether to assign the state dictionary to the model directly (default is False). - - Notes - ----- - - This function can remove a specific prefix from the keys in the state dictionary. - - It can also exclude certain keys from being loaded into the model. - """ - - # Prefix to remove - prefix = "model." - excluded_keys = ["loss.per_molecule_energy", "loss.per_atom_force"] - - # Create a new dictionary without the prefix in the keys if prefix exists - if any(key.startswith(prefix) for key in state_dict.keys()): - filtered_state_dict = { - key[len(prefix) :] if key.startswith(prefix) else key: value - for key, value in state_dict.items() - if key not in excluded_keys - } - log.debug(f"Removed prefix: {prefix}") - else: - # Create a filtered dictionary without excluded keys if no prefix exists - filtered_state_dict = { - k: v for k, v in state_dict.items() if k not in excluded_keys - } - log.debug("No prefix found. No modifications to keys in state loading.") - - super().load_state_dict(filtered_state_dict, strict=strict, assign=assign) - - def prepare_input(self, data): - - self.input_preparation._input_checks(data) - return self.input_preparation.prepare_inputs(data) - - def compute(self, data, core_input): - return self.core_module(data, core_input) - - def forward(self, data: NNPInput): - """ - Executes the forward pass of the model. - This method performs input checks, prepares the inputs, - and computes the outputs using the core network. - - Parameters - ---------- - data : NNPInput - The input data provided by the dataset, containing atomic numbers, positions, and other necessary information. - - Returns - ------- - Any - The outputs computed by the core network. - """ - - # perform input checks - core_input = self.prepare_input(data) - # prepare the input for the forward pass - output = self.compute(data, core_input) - # perform postprocessing operations - processed_output = self.postprocessing(output) - return processed_output - - -class CoreNetwork(Module, ABC): - def __init__( - self, - ): - """ - The CoreNetwork implements methods that are used by all neural network potentials. Every network inherits from CoreNetwork. - Networks are taking in a NNPInput and pairlist and returning a dictionary of **atomic** properties. - - Operations that are performed outside the network (e.g. pairlist calculation and operations that reduce atomic properties to molecule properties) are not part of the network and implemented in the BaseNetwork, which is a wrapper around the CoreNetwork. - """ - - super().__init__() - self._dtype: Optional[bool] = None # set at runtime - - @abstractmethod - def _model_specific_input_preparation( - self, data: NNPInput, pairlist: PairListOutputs - ) -> Union[ - "PhysNetNeuralNetworkData", - "PaiNNNeuralNetworkData", - "SchnetNeuralNetworkData", - "AniNeuralNetworkData", - "SAKENeuralNetworkInput", - ]: - """ - Prepares model-specific inputs before the forward pass. - - This abstract method should be implemented by subclasses to accommodate any - model-specific preprocessing of inputs. - - Parameters - ---------- - data : NNPInput - The initial inputs to the neural network model, including atomic numbers, positions, and other relevant data. - pairlist : PairListOutputs - The outputs of a pairlist calculation, including pair indices, distances, and displacement vectors. - - Returns - ------- - NeuralNetworkData - The processed inputs, ready for the model's forward pass. - """ - pass - - @abstractmethod - def compute_properties( - self, - data: Union[ - "PhysNetNeuralNetworkData", - "PaiNNNeuralNetworkData", - "SchnetNeuralNetworkData", - "AniNeuralNetworkData", - "SAKENeuralNetworkInput", - ], - ): - """ - Defines the forward pass of the model. - - This abstract method should be implemented by subclasses to specify the model's computation from inputs (processed input data) to outputs (per atom properties). - - Parameters - ---------- - data : The processed input data, specific to the model's requirements. - - Returns - ------- - Any - The model's output as computed from the inputs. - """ - pass - - def load_pretrained_weights(self, path: str): - """ - Loads pretrained weights into the model from the specified path. - - Parameters - ---------- - path : str - The path to the file containing the pretrained weights. - - Returns - ------- - None - """ - self.load_state_dict(torch.load(path, map_location=self.device)) - self.eval() # Set the model to evaluation mode - - def forward( - self, data: NNPInput, pairlist_output: PairListOutputs - ) -> Dict[str, torch.Tensor]: - """ - Implements the forward pass through the network. - - Parameters - ---------- - data : NNPInput - Contains input data for the batch obtained directly from the dataset, including atomic numbers, positions, - and other relevant fields. - pairlist_output : PairListOutputs - Contains the indices for the selected pairs and their associated distances and displacement vectors. - - Returns - ------- - Dict[str, torch.Tensor] - The calculated per-atom properties and other properties from the forward pass. - """ - # perform model specific modifications - nnp_input = self._model_specific_input_preparation(data, pairlist_output) - # perform the forward pass implemented in the subclass - outputs = self.compute_properties(nnp_input) - # add atomic numbers to the output - outputs["atomic_numbers"] = data.atomic_numbers - - return outputs diff --git a/modelforge/potential/neighbors.py b/modelforge/potential/neighbors.py new file mode 100644 index 00000000..b9b6ae90 --- /dev/null +++ b/modelforge/potential/neighbors.py @@ -0,0 +1,867 @@ +""" +This file contains classes for computing pairs and neighbors. +""" + +import torch +from loguru import logger as log +from modelforge.dataset.dataset import NNPInput + +from typing import Union, NamedTuple + + +class PairlistData(NamedTuple): + """ + A namedtuple to store the outputs of the Pairlist and Neighborlist forward methods. + + Attributes + ---------- + pair_indices : torch.Tensor + A tensor of shape (2, n_pairs) containing the indices of the interacting atom pairs. + d_ij : torch.Tensor + A tensor of shape (n_pairs, 1) containing the Euclidean distances between the atoms in each pair. + r_ij : torch.Tensor + A tensor of shape (n_pairs, 3) containing the displacement vectors between the atoms in each pair. + """ + + pair_indices: torch.Tensor + d_ij: torch.Tensor + r_ij: torch.Tensor + + +class Pairlist(torch.nn.Module): + def __init__(self, only_unique_pairs: bool = False): + """ + Handle pair list calculations for systems, returning indices, distances + and distance vectors for atom pairs within a certain cutoff. + + Parameters + ---------- + only_unique_pairs : bool, optional + If True, only unique pairs are returned (default is False). + """ + super().__init__() + self.only_unique_pairs = only_unique_pairs + + def enumerate_all_pairs(self, atomic_subsystem_indices: torch.Tensor): + """ + Compute all pairs of atoms and their distances. + + Parameters + ---------- + atomic_subsystem_indices : torch.Tensor + Atom indices to indicate which atoms belong to which molecule. + Note in all cases, the values in this tensor must be numbered from 0 to n_molecules - 1 sequentially, with no gaps in the numbering. E.g., [0,0,0,1,1,2,2,2 ...]. + This is the case for all internal data structures, and thus no validation is performed in this routine. If the data is not structured in this way, the results will be incorrect. + + Returns + ------- + torch.Tensor + Pair indices for all atom pairs. + """ + + # get device that passed tensors lives on, initialize on the same device + device = atomic_subsystem_indices.device + + # if there is only one molecule, we do not need to use additional looping and offsets + if torch.sum(atomic_subsystem_indices) == 0: + n = len(atomic_subsystem_indices) + if self.only_unique_pairs: + i_final_pairs, j_final_pairs = torch.triu_indices( + n, n, 1, device=device + ) + else: + # Repeat each number n-1 times for i_indices + i_final_pairs = torch.repeat_interleave( + torch.arange(n, device=device), + repeats=n - 1, + ) + + # Correctly construct j_indices + j_final_pairs = torch.cat( + [ + torch.cat( + ( + torch.arange(i, device=device), + torch.arange(i + 1, n, device=device), + ) + ) + for i in range(n) + ] + ) + + else: + # if we have more than one molecule, we will take into account molecule size and offsets when + # calculating pairs, as using the approach above is not memory efficient for datasets with large molecules + # and/or larger batch sizes; while not likely a problem on higher end GPUs with large amounts of memory + # cheaper commodity and mobile GPUs may have issues + + # atomic_subsystem_indices are always numbered from 0 to n_molecules + # - 1 e.g., a single molecule will be [0, 0, 0, 0 ... ] and a batch + # of molecules will always start at 0 and increment [ 0, 0, 0, 1, 1, + # 1, ...] As such, we can use bincount, as there are no gaps in the + # numbering + + # Note if the indices are not numbered from 0 to n_molecules - 1, this will not work + # E.g., bincount on [3,3,3, 4,4,4, 5,5,5] will return [0,0,0,3,3,3,3,3,3] + # as we have no values for 0, 1, 2 + # using a combination of unique and argsort would make this work for any numbering ordering + # but that is not how the data ends up being structured internally, and thus is not needed + repeats = torch.bincount(atomic_subsystem_indices) + offsets = torch.cat( + (torch.tensor([0], device=device), torch.cumsum(repeats, dim=0)[:-1]) + ) + + i_indices = torch.cat( + [ + torch.repeat_interleave( + torch.arange(o, o + r, device=device), repeats=r + ) + for r, o in zip(repeats, offsets) + ] + ) + j_indices = torch.cat( + [ + torch.cat([torch.arange(o, o + r, device=device) for _ in range(r)]) + for r, o in zip(repeats, offsets) + ] + ) + + if self.only_unique_pairs: + # filter out pairs that are not unique + unique_pairs_mask = i_indices < j_indices + i_final_pairs = i_indices[unique_pairs_mask] + j_final_pairs = j_indices[unique_pairs_mask] + else: + # filter out identical values + unique_pairs_mask = i_indices != j_indices + i_final_pairs = i_indices[unique_pairs_mask] + j_final_pairs = j_indices[unique_pairs_mask] + + # concatenate to form final (2, n_pairs) tensor + pair_indices = torch.stack((i_final_pairs, j_final_pairs)) + + return pair_indices.to(device) + + def construct_initial_pairlist_using_numpy( + self, atomic_subsystem_indices: torch.Tensor + ): + """Compute all pairs of atoms and also return counts of the number of pairs for each molecule in batch. + + Parameters + ---------- + atomic_subsystem_indices : torch.Tensor + Atom indices to indicate which atoms belong to which molecule. + + Returns + ------- + pair_indices : np.ndarray, shape (2, n_pairs) + Pairs of atom indices, 0-indexed for each molecule + number_of_pairs : np.ndarray, shape (n_molecules) + The number to index into pair_indices for each molecule + + """ + + # atomic_subsystem_indices are always numbered from 0 to n_molecules - 1 + # e.g., a single molecule will be [0, 0, 0, 0 ... ] + # and a batch of molecules will always start at 0 and increment [ 0, 0, 0, 1, 1, 1, ...] + # As such, we can use bincount, as there are no gaps in the numbering + # Note if the indices are not numbered from 0 to n_molecules - 1, this will not work + # E.g., bincount on [3,3,3, 4,4,4, 5,5,5] will return [0,0,0,3,3,3,3,3,3] + # as we have no values for 0, 1, 2 + # using a combination of unique and argsort would make this work for any numbering ordering + # but that is not how the data ends up being structured internally, and thus is not needed + + import numpy as np + + # get the number of atoms in each molecule + repeats = np.bincount(atomic_subsystem_indices) + + # calculate the number of pairs for each molecule, using simple permutation + npairs_by_molecule = np.array([r * (r - 1) for r in repeats], dtype=np.int16) + + i_indices = np.concatenate( + [ + np.repeat( + np.arange( + 0, + r, + dtype=np.int16, + ), + repeats=r, + ) + for r in repeats + ] + ) + j_indices = np.concatenate( + [ + np.concatenate([np.arange(0, 0 + r, dtype=np.int16) for _ in range(r)]) + for r in repeats + ] + ) + + # filter out identical pairs where i==j + unique_pairs_mask = i_indices != j_indices + i_final_pairs = i_indices[unique_pairs_mask] + j_final_pairs = j_indices[unique_pairs_mask] + + # concatenate to form final (2, n_pairs) vector + pair_indices = np.stack((i_final_pairs, j_final_pairs)) + + return pair_indices, npairs_by_molecule + + def calculate_r_ij( + self, pair_indices: torch.Tensor, positions: torch.Tensor + ) -> torch.Tensor: + """Compute displacement vectors between atom pairs. + + Parameters + ---------- + pair_indices : torch.Tensor + Atom indices for pairs of atoms. Shape: [2, n_pairs]. + positions : torch.Tensor + Atom positions. Shape: [atoms, 3]. + + Returns + ------- + torch.Tensor + Displacement vectors between atom pairs. Shape: [n_pairs, 3]. + """ + # Select the pairs of atom coordinates from the positions + selected_positions = positions.index_select(0, pair_indices.view(-1)).view( + 2, -1, 3 + ) + return selected_positions[1] - selected_positions[0] + + def calculate_d_ij(self, r_ij: torch.Tensor) -> torch.Tensor: + """ + ompute Euclidean distances between atoms in each pair. + + Parameters + ---------- + r_ij : torch.Tensor + Displacement vectors between atoms in a pair. Shape: [n_pairs, 3]. + + Returns + ------- + torch.Tensor + Euclidean distances. Shape: [n_pairs, 1]. + """ + return r_ij.norm(dim=1).unsqueeze(1) + + def forward( + self, + positions: torch.Tensor, + atomic_subsystem_indices: torch.Tensor, + ) -> PairlistData: + """ + Performs the forward pass of the Pairlist module. + + Parameters + ---------- + positions : torch.Tensor + Atom positions. Shape: [nr_atoms, 3]. + atomic_subsystem_indices (torch.Tensor, shape (nr_atoms_per_systems)): + Atom indices to indicate which atoms belong to which molecule. + + Returns + ------- + PairListOutputs: A dataclass containing the following attributes: + pair_indices (torch.Tensor): A tensor of shape (2, n_pairs) containing the indices of the interacting atom pairs. + d_ij (torch.Tensor): A tensor of shape (n_pairs, 1) containing the Euclidean distances between the atoms in each pair. + r_ij (torch.Tensor): A tensor of shape (n_pairs, 3) containing the displacement vectors between the atoms in each pair. + """ + pair_indices = self.enumerate_all_pairs( + atomic_subsystem_indices, + ) + r_ij = self.calculate_r_ij(pair_indices, positions) + return PairlistData( + pair_indices=pair_indices, + d_ij=self.calculate_d_ij(r_ij), + r_ij=r_ij, + ) + + +class OrthogonalDisplacementFunction(torch.nn.Module): + def __init__(self): + """ + Compute displacement vectors between pairs of atoms, considering periodic boundary conditions if used. + + """ + super().__init__() + + def forward( + self, + coordinate_i: torch.Tensor, + coordinate_j: torch.Tensor, + box_vectors: torch.Tensor, + is_periodic: torch.Tensor, + ): + """ + Compute displacement vectors and Euclidean distances between atom pairs. + + Parameters + ---------- + coordinate_i : torch.Tensor + Coordinates of the first atom in each pair. Shape: [n_pairs, 3]. + coordinate_j : torch.Tensor + Coordinates of the second atom in each pair. Shape: [n_pairs, 3]. + box_vectors : torch.Tensor + Box vectors defining the periodic boundary conditions. Shape: [3, 3]. + is_periodic : bool + Whether to apply periodic boundary conditions. + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Displacement vectors (r_ij) of shape [n_pairs, 3] and distances (d_ij) of shape [n_pairs, 1]. + """ + r_ij = coordinate_i - coordinate_j + + if is_periodic == True: + # Note, since box length may change, we need to update each time if periodic + # reinitializing this vector each time does not have a significant performance impact + + box_lengths = torch.zeros( + 3, device=box_vectors.device, dtype=box_vectors.dtype + ) + + box_lengths[0] = box_vectors[0][0] + box_lengths[1] = box_vectors[1][1] + box_lengths[2] = box_vectors[2][2] + + r_ij = ( + torch.remainder(r_ij + box_lengths / 2, box_lengths) - box_lengths / 2 + ) + + d_ij = torch.norm(r_ij, dim=1, keepdim=True, p=2) + return r_ij, d_ij + + +class NeighborlistBruteNsq(torch.nn.Module): + """ + Brute force N^2 neighbor list calculation for inference implemented fully in PyTorch. + + This is compatible with TorchScript. + + + """ + + def __init__( + self, + cutoff: float, + displacement_function: OrthogonalDisplacementFunction, + only_unique_pairs: bool = False, + ): + """ + Compute neighbor lists for inference, filtering pairs based on a cutoff distance. + + Parameters + ---------- + cutoff : float + The cutoff distance for neighbor list calculations. + displacement_function : OrthogonalDisplacementFunction + The function to calculate displacement vectors and distances between atom pairs, taking into account + the specified boundary conditions. + only_unique_pairs : bool, optional + Whether to only use unique pairs in the pair list calculation, by default False. + """ + + super().__init__() + + self.register_buffer("cutoff", torch.tensor(cutoff)) + self.register_buffer("only_unique_pairs", torch.tensor(only_unique_pairs)) + self.displacement_function = displacement_function + + self.indices = torch.tensor([]) + self.i_final_pairs = torch.tensor([]) + self.j_final_pairs = torch.tensor([]) + log.info("Initializing Brute Force N^2 Neighborlist") + + def _copy_to_nonunique( + self, + i_pairs: torch.Tensor, + j_pairs: torch.Tensor, + d_ij: torch.Tensor, + r_ij: torch.Tensor, + total_unique_pairs: int, + ): + r_ij_full = torch.zeros( + total_unique_pairs * 2, 3, dtype=r_ij.dtype, device=r_ij.device + ) + d_ij_full = torch.zeros( + total_unique_pairs * 2, 1, dtype=d_ij.dtype, device=d_ij.device + ) + + r_ij_full[0:total_unique_pairs] = r_ij + r_ij_full[total_unique_pairs : 2 * total_unique_pairs] = -r_ij + + d_ij_full[0:total_unique_pairs] = d_ij + d_ij_full[total_unique_pairs : 2 * total_unique_pairs] = d_ij + + pairs_full = torch.zeros( + 2, total_unique_pairs * 2, dtype=torch.int64, device=i_pairs.device + ) + + pairs_full[0][0:total_unique_pairs] = i_pairs + pairs_full[1][0:total_unique_pairs] = j_pairs + pairs_full[0][total_unique_pairs : 2 * total_unique_pairs] = j_pairs + pairs_full[1][total_unique_pairs : 2 * total_unique_pairs] = i_pairs + + return pairs_full, d_ij_full, r_ij_full + + def forward(self, data: NNPInput): + """ + Prepares the input tensors for passing to the model. + + This method handles general input manipulation, such as calculating + distances and generating the pair list. It also calls the model-specific + input preparation. + + Parameters + ---------- + data : NNPInput + The input data provided by the dataset, containing atomic numbers, + positions, and other necessary information. + + Returns + ------- + PairListOutputs + Contains pair indices, distances (d_ij), and displacement vectors (r_ij) for atom pairs within the cutoff. + """ + # --------------------------- + # general input manipulation + positions = data.positions + atomic_subsystem_indices = data.atomic_subsystem_indices + + n = atomic_subsystem_indices.size(0) + + # avoid reinitializing indices if they are already set and haven't changed + if self.indices.shape[0] != n: + # Generate a range of indices from 0 to n-1 + self.indices = torch.arange(n, device=atomic_subsystem_indices.device) + + # Create a meshgrid of indices + self.i_final_pairs, self.j_final_pairs = torch.meshgrid( + self.indices, self.indices, indexing="ij" + ) + # We will only consider unique pairs; for non-unique pairs we can just appropriately copy + # the data as it will be faster than extra computations. + mask = self.i_final_pairs < self.j_final_pairs + + self.i_final_pairs = self.i_final_pairs[mask] + self.j_final_pairs = self.j_final_pairs[mask] + + # calculate r_ij and d_ij + r_ij, d_ij = self.displacement_function( + positions[self.i_final_pairs], + positions[self.j_final_pairs], + data.box_vectors, + data.is_periodic, + ) + in_cutoff = (d_ij <= self.cutoff).squeeze() + total_pairs = in_cutoff.sum() + + if self.only_unique_pairs: + # using this instead of torch.stack to ensure that if we only have a single pair + # we don't run into an issue with tensor shapes. + # note this will fail if there are no interacting pairs + + pairs = torch.zeros( + 2, total_pairs, dtype=torch.int64, device=positions.device + ) + + pairs[0] = self.i_final_pairs[in_cutoff] + pairs[1] = self.j_final_pairs[in_cutoff] + + return PairlistData( + pair_indices=pairs, + d_ij=d_ij[in_cutoff], + r_ij=r_ij[in_cutoff], + ) + + else: + pairs_full, d_ij_full, r_ij_full = self._copy_to_nonunique( + self.i_final_pairs[in_cutoff], + self.j_final_pairs[in_cutoff], + d_ij[in_cutoff], + r_ij[in_cutoff], + total_pairs, + ) + return PairlistData( + pair_indices=pairs_full, + d_ij=d_ij_full, + r_ij=r_ij_full, + ) + + +class NeighborlistVerletNsq(torch.nn.Module): + """ + Verlet neighbor list calculation for inference implemented fully in PyTorch. + + Rebuilding of the neighborlist uses an N^2 approach. Rebuilding occurs when + the maximum displacement of any particle exceeds half the skin distance. + + """ + + def __init__( + self, + cutoff: float, + skin: float, + displacement_function: OrthogonalDisplacementFunction, + only_unique_pairs: bool = False, + ): + """ + Compute neighbor lists for inference, filtering pairs based on a cutoff distance. + + Parameters + ---------- + cutoff : float + The cutoff distance for neighbor list calculations. + skin : float + The skin distance for neighbor list calculations. + displacement_function : OrthogonalDisplacementFunction + The function to calculate displacement vectors and distances between atom pairs, taking into account + the specified boundary conditions. + only_unique_pairs : bool, optional + Whether to only use unique pairs in the pair list calculation, by + default True. This should be set to True for all message passing + networks. + """ + + super().__init__() + + self.register_buffer("cutoff", torch.tensor(cutoff)) + self.skin = skin + self.half_skin = skin * 0.5 + self.cutoff_plus_skin = cutoff + skin + self.only_unique_pairs = only_unique_pairs + + self.displacement_function = displacement_function + self.indices = torch.tensor([]) + self.i_pairs = torch.tensor([]) + self.j_pairs = torch.tensor([]) + + self.positions_old = torch.tensor([]) + self.nlist_pairs = torch.tensor([]) + self.builds = 0 + self.box_vectors = torch.zeros([3, 3]) + + log.info("Initializing Verlet Neighborlist with N^2 building routine.") + + def _check_nlist( + self, positions: torch.Tensor, box_vectors: torch.Tensor, is_periodic + ): + r_ij, d_ij = self.displacement_function( + self.positions_old, positions, box_vectors, is_periodic + ) + + if torch.any(d_ij > self.half_skin): + return True + else: + return False + + def _init_pairs(self, n_particles: int, device: torch.device): + self.indices = torch.arange(n_particles, device=device) + + i_pairs, j_pairs = torch.meshgrid( + self.indices, + self.indices, + indexing="ij", + ) + + mask = i_pairs < j_pairs + self.i_pairs = i_pairs[mask] + self.j_pairs = j_pairs[mask] + + def _build_nlist( + self, positions: torch.Tensor, box_vectors: torch.Tensor, is_periodic + ): + r_ij, d_ij = self.displacement_function( + positions[self.i_pairs], positions[self.j_pairs], box_vectors, is_periodic + ) + + in_cutoff = (d_ij < self.cutoff_plus_skin).squeeze() + self.nlist_pairs = torch.stack( + [self.i_pairs[in_cutoff], self.j_pairs[in_cutoff]] + ) + self.builds += 1 + return r_ij[in_cutoff], d_ij[in_cutoff] + + def _copy_to_nonunique( + self, + pairs: torch.Tensor, + d_ij: torch.Tensor, + r_ij: torch.Tensor, + total_unique_pairs: int, + ): + # this will allow us to copy the data for unique pairs to create the non-unique pairs data + r_ij_full = torch.zeros( + total_unique_pairs * 2, 3, dtype=r_ij.dtype, device=r_ij.device + ) + d_ij_full = torch.zeros( + total_unique_pairs * 2, 1, dtype=d_ij.dtype, device=d_ij.device + ) + + r_ij_full[0:total_unique_pairs] = r_ij + + # since we are swapping the order of the pairs, the sign changes + r_ij_full[total_unique_pairs : 2 * total_unique_pairs] = -r_ij + + d_ij_full[0:total_unique_pairs] = d_ij + d_ij_full[total_unique_pairs : 2 * total_unique_pairs] = d_ij + + pairs_full = torch.zeros( + 2, total_unique_pairs * 2, dtype=torch.int64, device=pairs.device + ) + + pairs_full[0][0:total_unique_pairs] = pairs[0] + pairs_full[1][0:total_unique_pairs] = pairs[1] + pairs_full[0][total_unique_pairs : 2 * total_unique_pairs] = pairs[1] + pairs_full[1][total_unique_pairs : 2 * total_unique_pairs] = pairs[0] + + return pairs_full, d_ij_full, r_ij_full + + def forward(self, data: NNPInput): + """ + Prepares the input tensors for passing to the model. + + This method handles general input manipulation, such as calculating + distances and generating the pair list. It also calls the model-specific + input preparation. + + Parameters + ---------- + data : NNPInput + The input data provided by the dataset, containing atomic numbers, + positions, and other necessary information. + + Returns + ------- + PairListOutputs + Contains pair indices, distances (d_ij), and displacement vectors (r_ij) for atom pairs within the cutoff. + """ + # --------------------------- + # general input manipulation + positions = data.positions + atomic_subsystem_indices = data.atomic_subsystem_indices + + n = atomic_subsystem_indices.size(0) + # if the initial build we haven't yet set box vectors so set them + # this is necessary because we need to store them to know if we need to force a rebuild + # because the box vectors have changed + if self.builds == 0: + self.box_vectors = data.box_vectors + + box_changed = torch.any(self.box_vectors != data.box_vectors) + + # avoid reinitializing indices if they are already set and haven't changed + if self.indices.shape[0] != n: + self.box_vectors = data.box_vectors + self.positions_old = positions + self._init_pairs(n, positions.device) + r_ij, d_ij = self._build_nlist( + positions, data.box_vectors, data.is_periodic + ) + elif box_changed: + # if the box vectors have changed, we need to rebuild the nlist + # but do not need to regenerate the pairs + self.box_vectors = data.box_vectors + self.positions_old = positions + r_ij, d_ij = self._build_nlist( + positions, data.box_vectors, data.is_periodic + ) + elif self._check_nlist(positions, data.box_vectors, data.is_periodic): + self.positions_old = positions + r_ij, d_ij = self._build_nlist( + positions, data.box_vectors, data.is_periodic + ) + else: + r_ij, d_ij = self.displacement_function( + positions[self.nlist_pairs[0]], + positions[self.nlist_pairs[1]], + data.box_vectors, + data.is_periodic, + ) + + in_cutoff = (d_ij <= self.cutoff).squeeze() + total_pairs = in_cutoff.sum() + + if self.only_unique_pairs: + # using this instead of torch.stack to ensure that if we only have a single pair + # we don't run into an issue with shapes. + + pairs = torch.zeros( + 2, total_pairs, dtype=torch.int64, device=positions.device + ) + + pairs[0] = self.nlist_pairs[0][in_cutoff] + pairs[1] = self.nlist_pairs[1][in_cutoff] + + return PairlistData( + pair_indices=pairs, + d_ij=d_ij[in_cutoff], + r_ij=r_ij[in_cutoff], + ) + + else: + pairs_full, d_ij_full, r_ij_full = self._copy_to_nonunique( + self.nlist_pairs[:, in_cutoff], + d_ij[in_cutoff], + r_ij[in_cutoff], + total_pairs, + ) + return PairlistData( + pair_indices=pairs_full, + d_ij=d_ij_full, + r_ij=r_ij_full, + ) + + +class NeighborListForTraining(torch.nn.Module): + def __init__(self, cutoff: float, only_unique_pairs: bool = False): + """ + Calculating the interacting pairs. This is primarily intended for use during training, + as this will utilize the pre-computed pair list from the dataset + + Parameters + ---------- + cutoff : float + The cutoff distance for neighbor list calculations. + only_unique_pairs : bool, optional + If True, only unique pairs are returned (default is False). + """ + + super().__init__() + + self.only_unique_pairs = only_unique_pairs + self.pairlist = Pairlist(only_unique_pairs) + self.register_buffer("cutoff", torch.tensor(cutoff)) + + def calculate_r_ij( + self, pair_indices: torch.Tensor, positions: torch.Tensor + ) -> torch.Tensor: + """Compute displacement vectors between atom pairs. + + Parameters + ---------- + pair_indices : torch.Tensor + Atom indices for pairs of atoms. Shape: [2, n_pairs]. + positions : torch.Tensor + Atom positions. Shape: [atoms, 3]. + + Returns + ------- + torch.Tensor + Displacement vectors between atom pairs. Shape: [n_pairs, 3]. + """ + # Select the pairs of atom coordinates from the positions + selected_positions = positions.index_select(0, pair_indices.view(-1)).view( + 2, -1, 3 + ) + return selected_positions[1] - selected_positions[0] + + def calculate_d_ij(self, r_ij: torch.Tensor) -> torch.Tensor: + """ + ompute Euclidean distances between atoms in each pair. + + Parameters + ---------- + r_ij : torch.Tensor + Displacement vectors between atoms in a pair. Shape: [n_pairs, 3]. + + Returns + ------- + torch.Tensor + Euclidean distances. Shape: [n_pairs, 1]. + """ + return r_ij.norm(dim=1).unsqueeze(1) + + def _calculate_interacting_pairs( + self, + positions: torch.Tensor, + atomic_subsystem_indices: torch.Tensor, + pair_indices: torch.Tensor, + ) -> PairlistData: + """ + Compute the neighbor list considering a cutoff distance. + + Parameters + ---------- + positions : torch.Tensor + Atom positions. Shape: [nr_systems, nr_atoms, 3]. + atomic_subsystem_indices : torch.Tensor + Indices identifying atoms in subsystems. Shape: [nr_atoms]. + pair_indices : torch.Tensor + Precomputed pair indices. + + Returns + ------- + PairListOutputs + A dataclass containing 'pair_indices', 'd_ij' (distances), and 'r_ij' (displacement vectors). + """ + + r_ij = self.calculate_r_ij(pair_indices, positions) + d_ij = self.calculate_d_ij(r_ij) + + in_cutoff = (d_ij <= self.cutoff).squeeze() + # Get the atom indices within the cutoff + pair_indices_within_cutoff = pair_indices[:, in_cutoff] + + return PairlistData( + pair_indices=pair_indices_within_cutoff, + d_ij=d_ij[in_cutoff], + r_ij=r_ij[in_cutoff], + ) + + def forward(self, data: Union[NNPInput, NamedTuple]) -> PairlistData: + """ + Compute the pair list, distances, and displacement vectors for the given + input data. + + Parameters + ---------- + data : Union[NNPInput, NamedTuple] + Input data containing atomic numbers, positions, and subsystem + indices. + + Returns + ------- + PairlistData + A namedtuple containing the pair indices, distances, and + displacement vectors. + """ + # --------------------------- + # general input manipulation + positions = data.positions + atomic_subsystem_indices = data.atomic_subsystem_indices + # calculate pairlist if it is not provided + + if data.pair_list is None or data.pair_list.shape[0] == 0: + # note, we set the flag for unique pairs when instantiated in the constructor + # and thus this call will return unique pairs if requested. + pair_list = self.pairlist.enumerate_all_pairs(atomic_subsystem_indices) + + else: + pair_list = data.pair_list + + # when we precompute the pairlist, we included all pairs, including non-unique + # since we do this before we know about which potential we are using + # and whether we require unique pairs or not + # thus, if the pairlist is provided we need to remove redundant pairs if requested + if self.only_unique_pairs: + i_indices = pair_list[0] + j_indices = pair_list[1] + unique_pairs_mask = i_indices < j_indices + i_final_pairs = i_indices[unique_pairs_mask] + j_final_pairs = j_indices[unique_pairs_mask] + pair_list = torch.stack((i_final_pairs, j_final_pairs)) + + pairlist_output = self._calculate_interacting_pairs( + positions=positions, + atomic_subsystem_indices=atomic_subsystem_indices, + pair_indices=pair_list.to(torch.int64), + ) + + return pairlist_output + + +# diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index 48b3a2d6..0525924b 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -1,253 +1,271 @@ -from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union +""" +PaiNN - polarizable interaction neural network +""" + +from typing import Dict, List, Tuple, Type, Union import torch import torch.nn as nn -import torch.nn.functional as F from loguru import logger as log from openff.units import unit -from .models import InputPreparation, NNPInput, BaseNetwork, CoreNetwork - -from .utils import Dense - -if TYPE_CHECKING: - from .models import PairListOutputs - from modelforge.dataset.dataset import NNPInput - -from dataclasses import dataclass, field - -from modelforge.potential.utils import NeuralNetworkData - - -@dataclass -class PaiNNNeuralNetworkData(NeuralNetworkData): - """ - A dataclass designed to structure the inputs for PaiNN neural network potentials, ensuring - an efficient and structured representation of atomic systems for energy computation and - property prediction within the PaiNN framework. - - Attributes - ---------- - atomic_numbers : torch.Tensor - Atomic numbers for each atom in the system(s). Shape: [num_atoms]. - positions : torch.Tensor - XYZ coordinates of each atom. Shape: [num_atoms, 3]. - atomic_subsystem_indices : torch.Tensor - Maps each atom to its respective subsystem or molecule, useful for systems with multiple - molecules. Shape: [num_atoms]. - total_charge : torch.Tensor - Total charge of each system or molecule. Shape: [num_systems]. - pair_indices : torch.Tensor - Indicates indices of atom pairs, essential for computing pairwise features. Shape: [2, num_pairs]. - d_ij : torch.Tensor - Distances between each pair of atoms, derived from `pair_indices`. Shape: [num_pairs, 1]. - r_ij : torch.Tensor - Displacement vectors between atom pairs, providing directional context. Shape: [num_pairs, 3]. - number_of_atoms : int - Total number of atoms in the batch, facilitating batch-wise operations. - atomic_embedding : torch.Tensor - Embeddings or features for each atom, potentially derived from atomic numbers or learned. Shape: [num_atoms, embedding_dim]. - - Notes - ----- - The `PaiNNNeuralNetworkInput` dataclass encapsulates essential inputs required by the PaiNN neural network - model for accurately predicting system energies and properties. It includes atomic positions, atomic types, - and connectivity information, crucial for a detailed representation of atomistic systems. - - Examples - -------- - >>> painn_input = PaiNNNeuralNetworkData( - ... atomic_numbers=torch.tensor([1, 6, 6, 8]), - ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]), - ... atomic_subsystem_indices=torch.tensor([0, 0, 0, 0]), - ... total_charge=torch.tensor([0.0]), - ... pair_indices=torch.tensor([[0, 1], [0, 2], [1, 2]]).T, - ... d_ij=torch.tensor([[1.0], [1.0], [1.0]]), - ... r_ij=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), - ... number_of_atoms=4, - ... atomic_embedding=torch.randn(4, 5) # Example atomic embeddings - ... ) - """ - - atomic_embedding: torch.Tensor - - -class PaiNNCore(CoreNetwork): - """PaiNN - polarizable interaction neural network - - References: - Equivariant message passing for the prediction of tensorial properties and molecular spectra. - ICML 2021, http://proceedings.mlr.press/v139/schutt21a.html - - """ + +from modelforge.utils.prop import NNPInput +from modelforge.potential.neighbors import PairlistData +from .utils import DenseWithCustomDist + + +class PaiNNCore(torch.nn.Module): def __init__( self, - max_Z: int = 100, - number_of_atom_features: int = 64, - number_of_radial_basis_functions: int = 16, - cutoff: unit.Quantity = 5 * unit.angstrom, - number_of_interaction_modules: int = 2, - shared_interactions: bool = False, - shared_filters: bool = False, + featurization: Dict[str, Dict[str, int]], + number_of_radial_basis_functions: int, + maximum_interaction_radius: float, + number_of_interaction_modules: int, + shared_interactions: bool, + shared_filters: bool, + activation_function_parameter: Dict[str, str], + predicted_properties: List[str], + predicted_dim: List[int], epsilon: float = 1e-8, + potential_seed: int = -1, ): - log.debug("Initializing PaiNN model.") - super().__init__() + """ + Core PaiNN architecture for modeling polarizable molecular interactions. - self.number_of_interaction_modules = number_of_interaction_modules - self.number_of_atom_features = number_of_atom_features - self.shared_filters = shared_filters + Parameters + ---------- + featurization : Dict[str, Dict[str, int]] + Configuration for atom featurization, including number of features + per atom. + number_of_radial_basis_functions : int + Number of radial basis functions for the PaiNN representation. + maximum_interaction_radius : float + Maximum interaction radius for atom pairs. + number_of_interaction_modules : int + Number of interaction modules to apply. + shared_interactions : bool + Whether to share weights across all interaction modules. + shared_filters : bool + Whether to share filters across blocks. + activation_function_parameter : Dict[str, str] + Dictionary containing the activation function to use. + predicted_properties : List[str] + List of properties to predict. + predicted_dim : List[int] + List of dimensions for each predicted property. + epsilon : float, optional + Small constant for numerical stability (default is 1e-8). + potential_seed : int, optional + Seed for random number generation (default is -1). + """ - # embedding - from modelforge.potential.utils import Embedding + from modelforge.utils.misc import seed_random_number - self.embedding_module = Embedding(max_Z, number_of_atom_features) + if potential_seed != -1: + seed_random_number(potential_seed) + super().__init__() + log.debug("Initializing the PaiNN architecture.") + self.activation_function = activation_function_parameter["activation_function"] + + self.number_of_interaction_modules = number_of_interaction_modules + + # Featurize the atomic input + number_of_per_atom_features = int( + featurization["atomic_number"]["number_of_per_atom_features"] + ) # initialize representation block self.representation_module = PaiNNRepresentation( - cutoff, + maximum_interaction_radius, number_of_radial_basis_functions, number_of_interaction_modules, - number_of_atom_features, + number_of_per_atom_features, shared_filters, + featurization_config=featurization, ) # initialize the interaction and mixing networks - self.interaction_modules = nn.ModuleList( - PaiNNInteraction(number_of_atom_features, activation=F.silu) - for _ in range(number_of_interaction_modules) - ) - self.mixing_modules = nn.ModuleList( - PaiNNMixing(number_of_atom_features, activation=F.silu, epsilon=epsilon) - for _ in range(number_of_interaction_modules) - ) - - self.energy_layer = nn.Sequential( - Dense( - number_of_atom_features, number_of_atom_features, activation=nn.ReLU() - ), - Dense( - number_of_atom_features, - 1, - ), - ) + if shared_interactions: + self.message_function = nn.ModuleList( + [ + Message( + number_of_per_atom_features, + activation_function=self.activation_function, + ) + ] + * number_of_interaction_modules + ) + else: + self.message_function = nn.ModuleList( + [ + Message( + number_of_per_atom_features, + activation_function=self.activation_function, + ) + for _ in range(number_of_interaction_modules) + ] + ) - def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" - ) -> PaiNNNeuralNetworkData: - # Perform atomic embedding - - number_of_atoms = data.atomic_numbers.shape[0] - - nnp_input = PaiNNNeuralNetworkData( - pair_indices=pairlist_output.pair_indices, - d_ij=pairlist_output.d_ij, - r_ij=pairlist_output.r_ij, - number_of_atoms=number_of_atoms, - positions=data.positions, - atomic_numbers=data.atomic_numbers, - atomic_subsystem_indices=data.atomic_subsystem_indices, - total_charge=data.total_charge, - atomic_embedding=self.embedding_module( - data.atomic_numbers - ), # atom embedding + self.update_function = nn.ModuleList( + [ + Update( + number_of_per_atom_features, + activation_function=self.activation_function, + epsilon=epsilon, + ) + for _ in range(number_of_interaction_modules) + ] ) - return nnp_input + # Initialize output layers based on configuration + self.output_layers = nn.ModuleDict() + for property, dim in zip(predicted_properties, predicted_dim): + self.output_layers[property] = nn.Sequential( + DenseWithCustomDist( + number_of_per_atom_features, + number_of_per_atom_features, + activation_function=self.activation_function, + ), + DenseWithCustomDist( + number_of_per_atom_features, + int(dim), + ), + ) def compute_properties( - self, - data: PaiNNNeuralNetworkData, - ): + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: """ - Compute atomic representations/embeddings. + Compute atomic representations and embeddings using PaiNN. Parameters ---------- - data : PaiNNNeuralNetworkInput(NamedTuple) - atomic_embedding : torch.Tensor - Tensor containing atomic number embeddings. + data : NNPInput + The input data containing atomic numbers, positions, etc. + pairlist_output : PairlistData + The output from the pairlist module. Returns ------- Dict[str, torch.Tensor] - Dictionary containing scalar and vector representations. + Dictionary containing scalar and vector atomic representations. """ - - # initialize filters, q and mu - transformed_input = self.representation_module(data) + # Compute filters, scalar features (q), and vector features (mu) + transformed_input = self.representation_module(data, pairlist_output) filter_list = transformed_input["filters"] - q = transformed_input["q"] - mu = transformed_input["mu"] + per_atom_scalar_feature = transformed_input["per_atom_scalar_feature"] + per_atom_vector_feature = transformed_input["per_atom_vector_feature"] dir_ij = transformed_input["dir_ij"] + # Apply interaction and mixing modules for i, (interaction_mod, mixing_mod) in enumerate( - zip(self.interaction_modules, self.mixing_modules) + zip(self.message_function, self.update_function) ): - q, mu = interaction_mod( - q, - mu, + per_atom_scalar_feature, per_atom_vector_feature = interaction_mod( + per_atom_scalar_feature, + per_atom_vector_feature, filter_list[i], dir_ij, - data.pair_indices, + pairlist_output.pair_indices, + ) + per_atom_scalar_feature, per_atom_vector_feature = mixing_mod( + per_atom_scalar_feature, per_atom_vector_feature ) - q, mu = mixing_mod(q, mu) - - # Use squeeze to remove dimensions of size 1 - q = q.squeeze(dim=1) - E_i = self.energy_layer(q).squeeze(1) return { - "per_atom_energy": E_i, - "mu": mu, - "q": q, + "per_atom_scalar_representation": per_atom_scalar_feature.squeeze(1), + "per_atom_vector_representation": per_atom_vector_feature, "atomic_subsystem_indices": data.atomic_subsystem_indices, + "atomic_numbers": data.atomic_numbers, } + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: + """ + Forward pass of the PaiNN model. -from openff.units import unit + Parameters + ---------- + data : NNPInput + Input data including atomic numbers, positions, etc. + pairlist_output : PairlistData + Pair indices and distances from the pairlist module. + + Returns + ------- + Dict[str, torch.Tensor] + The predicted properties from the forward pass. + """ + # Compute properties using the core PaiNN modules + results = self.compute_properties(data, pairlist_output) + # Apply output layers to the atomic embedding + atomic_embedding = results["per_atom_scalar_representation"] + for output_name, output_layer in self.output_layers.items(): + results[output_name] = output_layer(atomic_embedding) + + return results class PaiNNRepresentation(nn.Module): - """PaiNN representation module""" def __init__( self, - cutoff: unit.Quantity = 5 * unit.angstrom, - number_of_radial_basis_functions: int = 16, - nr_interaction_blocks: int = 3, - nr_atom_basis: int = 8, - shared_filters: bool = False, + maximum_interaction_radius: float, + number_of_radial_basis_functions: int, + nr_interaction_blocks: int, + nr_atom_basis: int, + shared_filters: bool, + featurization_config: Dict[str, Union[List[str], int]], ): - super().__init__() + """ + PaiNN representation module for generating scalar and vector atomic embeddings. - # cutoff - from modelforge.potential import CosineCutoff + Parameters + ---------- + maximum_interaction_radius : float + Maximum interaction radius for atomic pairs in nanometer. + number_of_radial_basis_functions : int + Number of radial basis functions. + nr_interaction_blocks : int + Number of interaction blocks. + nr_atom_basis : int + Number of features to describe atomic environments. + shared_filters : bool + Whether to share filters across blocks. + featurization_config : Dict[str, Union[List[str], int]] + Configuration for atom featurization. + """ + from modelforge.potential import CosineAttenuationFunction, FeaturizeInput - self.cutoff_module = CosineCutoff(cutoff) + from .representation import SchnetRadialBasisFunction - # radial symmetry function - from .utils import SchnetRadialBasisFunction + super().__init__() + + self.featurize_input = FeaturizeInput(featurization_config) + + # Initialize the cutoff function and radial symmetry functions + self.cutoff_module = CosineAttenuationFunction(maximum_interaction_radius) self.radial_symmetry_function_module = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, - max_distance=cutoff, + max_distance=maximum_interaction_radius, dtype=torch.float32, ) # initialize the filter network if shared_filters: - filter_net = Dense( - number_of_radial_basis_functions, - 3 * nr_atom_basis, + filter_net = DenseWithCustomDist( + in_features=number_of_radial_basis_functions, + out_features=3 * nr_atom_basis, ) else: - filter_net = Dense( - number_of_radial_basis_functions, - nr_interaction_blocks * nr_atom_basis * 3, - activation=None, + filter_net = DenseWithCustomDist( + in_features=number_of_radial_basis_functions, + out_features=nr_interaction_blocks * nr_atom_basis * 3, ) self.filter_net = filter_net @@ -256,108 +274,113 @@ def __init__( self.nr_interaction_blocks = nr_interaction_blocks self.nr_atom_basis = nr_atom_basis - def forward(self, data: PaiNNNeuralNetworkData): + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: """ - Transforms the input data for the PAInn potential model. + Generate atomic embeddings and filters for PaiNN. Parameters ---------- - inputs (Dict[str, torch.Tensor]): A dictionary containing the input tensors. - - "d_ij" (torch.Tensor): Pairwise distances between atoms. Shape: (n_pairs, 1, 1). - - "r_ij" (torch.Tensor): Displacement vector between atoms. Shape: (n_pairs, 1, 3). - - "atomic_embedding" (torch.Tensor): Embeddings of atomic numbers. Shape: (n_atoms, 1, embedding_dim). + data : NNPInput + The input data containing atomic numbers, positions, etc. + pairlist_output : PairlistData + The output from the pairlist module, containing pair indices and distances. Returns - ---------- - Dict[str, torch.Tensor]: + ------- + Dict[str, torch.Tensor] A dictionary containing the transformed input tensors. - - "mu" (torch.Tensor) - Zero-initialized tensor for atom features. Shape: (n_atoms, 3, nr_atom_basis). - - "dir_ij" (torch.Tensor) - Direction vectors between atoms. Shape: (n_pairs, 1, distance). - - "q" (torch.Tensor): Reshaped atomic number embeddings. Shape: (n_atoms, 1, embedding_dim). """ - # compute normalized pairwise distances - d_ij = data.d_ij - r_ij = data.r_ij - # normalize the direction vectors - dir_ij = r_ij / d_ij + d_ij = pairlist_output.d_ij + dir_ij = pairlist_output.r_ij / d_ij + # featurize pairwise distances using radial basis functions (RBF) f_ij = self.radial_symmetry_function_module(d_ij) - - fcut = self.cutoff_module(d_ij) - - filters = self.filter_net(f_ij) * fcut - + # Apply the filter network and cutoff function + filters = torch.mul(self.filter_net(f_ij), self.cutoff_module(d_ij)) + # depending on whether we share filters or not filters have different + # shape at dim=1 (dim=0 is always the number of atom pairs) if we share + # filters, we copy the filters and use the same filters for all blocks if self.shared_filters: - filter_list = [filters] * self.nr_interaction_blocks + filter_list = torch.stack([filters] * self.nr_interaction_blocks, dim=0) + # otherwise we index into subset of the calculated filters and provide + # each block with its own set of filters else: - filter_list = torch.split(filters, 3 * self.nr_atom_basis, dim=-1) - - # generate q and mu - atomic_embedding = data.atomic_embedding - q = atomic_embedding[:, None] # nr_of_atoms, 1, nr_atom_basis - q_shape = q.shape - mu = torch.zeros( - (q_shape[0], 3, q_shape[2]), device=q.device, dtype=q.dtype - ) # nr_of_atoms, 3, nr_atom_basis + filter_list = torch.stack( + torch.split(filters, 3 * self.nr_atom_basis, dim=-1), dim=0 + ) - return {"filters": filter_list, "dir_ij": dir_ij, "q": q, "mu": mu} + # Initialize scalar and vector features + per_atom_scalar_feature = self.featurize_input(data).unsqueeze( + 1 + ) # nr_of_atoms, 1, nr_atom_basis + atomic_embedding_shape = per_atom_scalar_feature.shape + per_atom_vector_feature = torch.zeros( + (atomic_embedding_shape[0], 3, atomic_embedding_shape[2]), + device=per_atom_scalar_feature.device, + dtype=per_atom_scalar_feature.dtype, + ) # nr_of_atoms, 3, nr_atom_basis + return { + "filters": filter_list, + "dir_ij": dir_ij, + "per_atom_scalar_feature": per_atom_scalar_feature, + "per_atom_vector_feature": per_atom_vector_feature, + } -class PaiNNInteraction(nn.Module): - """ - PaiNN Interaction Block for Modeling Equivariant Interactions of Atomistic Systems. - """ +class Message(nn.Module): - def __init__(self, nr_atom_basis: int, activation: Callable): + def __init__( + self, + nr_atom_basis: int, + activation_function: torch.nn.Module, + ): """ - Parameters - ---------- - nr_atom_basis : int - Number of features to describe atomic environments. - activation : Callable - Activation function to use. + PaiNN message block for modeling scalar and vector interactions between atoms. - Attributes + Parameters ---------- nr_atom_basis : int Number of features to describe atomic environments. - interatomic_net : nn.Sequential - Neural network for interatomic interactions. + activation_function : Type[torch.nn.Module] + Activation function to use in the interaction block. """ super().__init__() self.nr_atom_basis = nr_atom_basis - # Initialize the intra-atomic neural network + # Initialize the interatomic network self.interatomic_net = nn.Sequential( - Dense(nr_atom_basis, nr_atom_basis, activation=activation), - Dense(nr_atom_basis, 3 * nr_atom_basis, activation=None), + DenseWithCustomDist( + nr_atom_basis, nr_atom_basis, activation_function=activation_function + ), + DenseWithCustomDist(nr_atom_basis, 3 * nr_atom_basis), ) def forward( self, - q: torch.Tensor, - mu: torch.Tensor, + per_atom_scalar_representation: torch.Tensor, + per_atom_vector_representation: torch.Tensor, W_ij: torch.Tensor, dir_ij: torch.Tensor, pairlist: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute interaction output. + """ + Forward pass of the message block. Parameters ---------- - q : torch.Tensor - Scalar input values of shape [nr_of_atoms, 1, nr_atom_basis]. - mu : torch.Tensor - Vector input values of shape [nr_of_atoms, 3, nr_atom_basis]. + per_atom_scalar_representation : torch.Tensor + Scalar input values (Shape: [nr_atoms, 1, nr_atom_basis]). + per_atom_vector_representation : torch.Tensor + Vector input values (Shape: [nr_atoms, 3, nr_atom_basis]). W_ij : torch.Tensor - Filter of shape [nr_of_pairs, 1, n_interactions]. + Interaction filters (Shape: [nr_pairs, 1, nr_interactions]). dir_ij : torch.Tensor - Directional vector between atoms i and j. - pairlist : torch.Tensor, shape (2, n_pairs) + Direction vectors between atoms i and j. + pairlist : torch.Tensor Returns ------- @@ -367,188 +390,118 @@ def forward( # perform the scalar operations (same as in SchNet) idx_i, idx_j = pairlist[0], pairlist[1] - x_per_atom = self.interatomic_net(q) # per atom + # Compute scalar interactions (q) + transformed_per_atom_scalar_representation = self.interatomic_net( + per_atom_scalar_representation + ) # per atom + s_j = transformed_per_atom_scalar_representation[idx_j] # per pair + weighted_s_j = W_ij.unsqueeze(1) * s_j # per_pair - x_j = x_per_atom[idx_j] # per pair - x_per_pair = W_ij.unsqueeze(1) * x_j # per_pair - - # split the output into dq, dmuR, dmumu to excchange information between the scalar and vector outputs - dq_per_pair, dmuR, dmumu = torch.split(x_per_pair, self.nr_atom_basis, dim=-1) + # split the output into 3x per_pair_ds to exchange information between the scalar and vector outputs + per_pair_ds1, per_pair_ds2, per_pair_ds3 = torch.split( + weighted_s_j, self.nr_atom_basis, dim=-1 + ) - # for scalar output only dq is used - # scatter the dq to the atoms (reducton from pairs to atoms) - dq_per_atom = torch.zeros_like(q) # Shape: (nr_of_pairs, 1, nr_atom_basis) + # Update scalar feature + ds_i = torch.zeros_like(per_atom_scalar_representation) # Expand idx_i to match the shape of dq for scatter_add operation - expanded_idx_i = idx_i.unsqueeze(-1).expand(-1, dq_per_pair.size(2)) - - dq_per_atom.scatter_add_(0, expanded_idx_i.unsqueeze(1), dq_per_pair) - - q = q + dq_per_atom + expanded_idx_i = idx_i.unsqueeze(-1).expand(-1, per_pair_ds1.size(2)) + ds_i.scatter_add_(0, expanded_idx_i.unsqueeze(1), per_pair_ds1) + per_atom_scalar_representation = per_atom_scalar_representation + ds_i # ----------------- vector output ----------------- - # for vector output dmuR and dmumu are used - # dmuR: (nr_of_pairs, 1, nr_atom_basis) - # dir_ij: (nr_of_pairs, 3) - # dmumu: (nr_of_pairs, 1, nr_atom_basis) - # muj: (nr_of_pairs, 1, nr_atom_basis) - # idx_i: (nr_of_pairs) - # mu: (nr_of_atoms, 3, nr_atom_basis) - - muj = mu[idx_j] # shape (nr_of_pairs, 1, nr_atom_basis) - - dmu_per_pair = ( - dmuR * dir_ij.unsqueeze(-1) + dmumu * muj - ) # shape (nr_of_pairs, 3, nr_atom_basis) + # Compute vector interactions (dv_i) + v_j = per_atom_vector_representation[idx_j] + dmu_per_pair = per_pair_ds2 * dir_ij.unsqueeze(-1) + per_pair_ds3 * v_j # Create a tensor to store the result, matching the size of `mu` - dmu_per_atom = torch.zeros_like(mu) # Shape: (nr_of_atoms, 3, nr_atom_basis) - + dv_i = torch.zeros_like( + per_atom_vector_representation + ) # Shape: (nr_of_atoms, 3, nr_atom_basis) # Expand idx_i to match the shape of dmu for scatter_add operation expanded_idx_i = ( - idx_i.unsqueeze(-1) - .unsqueeze(-1) - .expand(-1, dmu_per_atom.size(1), dmu_per_atom.size(2)) + idx_i.unsqueeze(-1).unsqueeze(-1).expand(-1, dv_i.size(1), dv_i.size(2)) ) - # Perform scatter_add_ operation - dmu_per_atom.scatter_add_(0, expanded_idx_i, dmu_per_pair) + dv_i.scatter_add_(0, expanded_idx_i, dmu_per_pair) - mu = mu + dmu_per_atom + per_atom_vector_representation = per_atom_vector_representation + dv_i - return q, mu + return per_atom_scalar_representation, per_atom_vector_representation -class PaiNNMixing(nn.Module): - r"""PaiNN interaction block for mixing on atom features.""" +class Update(nn.Module): - def __init__(self, nr_atom_basis: int, activation: Callable, epsilon: float = 1e-8): + def __init__( + self, + nr_atom_basis: int, + activation_function: torch.nn.Module, + epsilon: float = 1e-8, + ): """ + PaiNN update block. + Parameters ---------- nr_atom_basis : int Number of features to describe atomic environments. - activation : Callable + activation_function : torch.nn.Module Activation function to use. epsilon : float, optional - Stability constant added in norm to prevent numerical instabilities. Default is 1e-8. - - Attributes - ---------- - nr_atom_basis : int - Number of features to describe atomic environments. - intra_atomic_net : nn.Sequential - Neural network for intra-atomic interactions. - mu_channel_mix : nn.Sequential - Neural network for mixing mu channels. - epsilon : float - Stability constant for numerical stability. + Stability constant added to prevent numerical instabilities (default is 1e-8). """ super().__init__() self.nr_atom_basis = nr_atom_basis # initialize the intra-atomic neural network self.intra_atomic_net = nn.Sequential( - Dense(2 * nr_atom_basis, nr_atom_basis, activation=activation), - Dense(nr_atom_basis, 3 * nr_atom_basis, activation=None), + DenseWithCustomDist( + 2 * nr_atom_basis, + nr_atom_basis, + activation_function=activation_function, + ), + DenseWithCustomDist(nr_atom_basis, 3 * nr_atom_basis), + ) + # Initialize the channel mixing network for mu + self.linear_transformation = DenseWithCustomDist( + nr_atom_basis, 2 * nr_atom_basis, bias=False ) - # initialize the mu channel mixing network - self.mu_channel_mix = Dense(nr_atom_basis, 2 * nr_atom_basis, bias=False) self.epsilon = epsilon - def forward(self, q: torch.Tensor, mu: torch.Tensor): + def forward( + self, scalar_message: torch.Tensor, vector_message: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """ - compute intratomic mixing + Forward pass through node update. Parameters ---------- - q : torch.Tensor + scalar_message : torch.Tensor Scalar input values. - mu : torch.Tensor + vector_message : torch.Tensor Vector input values. Returns ------- Tuple[torch.Tensor, torch.Tensor] - Updated scalar and vector representations (q, mu). + Updated scalar and vector representations . """ - mu_mix = self.mu_channel_mix(mu) - mu_V, mu_W = torch.split(mu_mix, self.nr_atom_basis, dim=-1) - mu_Vn = torch.sqrt(torch.sum(mu_V**2, dim=-2, keepdim=True) + self.epsilon) - - ctx = torch.cat([q, mu_Vn], dim=-1) - x = self.intra_atomic_net(ctx) - - dq_intra, dmu_intra, dqmu_intra = torch.split(x, self.nr_atom_basis, dim=-1) - dmu_intra = dmu_intra * mu_W + vector_meassge_transformed = self.linear_transformation(vector_message) - dqmu_intra = dqmu_intra * torch.sum(mu_V * mu_W, dim=1, keepdim=True) + v_V, v_U = torch.split(vector_meassge_transformed, self.nr_atom_basis, dim=-1) - q = q + dq_intra + dqmu_intra - mu = mu + dmu_intra - return q, mu + L2_norm_v_V = torch.sqrt(torch.sum(v_V**2, dim=-2, keepdim=True) + self.epsilon) + ctx = torch.cat([scalar_message, L2_norm_v_V], dim=-1) + transformed_scalar_message = self.intra_atomic_net(ctx) -from .models import InputPreparation, NNPInput, BaseNetwork -from typing import List - - -class PaiNN(BaseNetwork): - def __init__( - self, - max_Z: int, - number_of_atom_features: int, - number_of_radial_basis_functions: int, - cutoff: Union[unit.Quantity, str], - number_of_interaction_modules: int, - shared_interactions: bool, - shared_filters: bool, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic: Optional[Dict[str, float]] = None, - epsilon: float = 1e-8, - ) -> None: - - from modelforge.utils.units import _convert - - self.only_unique_pairs = False # NOTE: for pairlist - - super().__init__( - dataset_statistic=dataset_statistic, - postprocessing_parameter=postprocessing_parameter, - cutoff=_convert(cutoff), + a_ss, a_vv, a_sv = torch.split( + transformed_scalar_message, self.nr_atom_basis, dim=-1 ) - self.core_module = PaiNNCore( - max_Z=max_Z, - number_of_atom_features=number_of_atom_features, - number_of_radial_basis_functions=number_of_radial_basis_functions, - cutoff=_convert(cutoff), - number_of_interaction_modules=number_of_interaction_modules, - shared_interactions=shared_interactions, - shared_filters=shared_filters, - epsilon=epsilon, - ) - - def _config_prior(self): - log.info("Configuring PaiNN model hyperparameter prior distribution") - from modelforge.utils.io import import_ - - tune = import_("ray").tune - # from ray import tune - - from modelforge.potential.utils import shared_config_prior - - prior = { - "number_of_atom_features": tune.randint(2, 256), - "number_of_interaction_modules": tune.randint(1, 5), - "cutoff": tune.uniform(5, 10), - "number_of_radial_basis_functions": tune.randint(8, 32), - "shared_filters": tune.choice([True, False]), - "shared_interactions": tune.choice([True, False]), - } - prior.update(shared_config_prior()) - return prior + a_vv = a_vv * v_U + a_sv = a_sv * torch.sum(v_V * v_U, dim=1, keepdim=True) - def combine_per_atom_properties( - self, values: Dict[str, torch.Tensor] - ) -> torch.Tensor: - return values + scalar_message = scalar_message + a_ss + a_sv + vector_message = vector_message + a_vv + return scalar_message, vector_message diff --git a/modelforge/potential/parameters.py b/modelforge/potential/parameters.py new file mode 100644 index 00000000..cb94d52a --- /dev/null +++ b/modelforge/potential/parameters.py @@ -0,0 +1,336 @@ +""" +This module contains pydantic models for storing the parameters of the potentials. +""" + +from enum import Enum +from typing import List, Optional, Type, Union + +import torch +from pydantic import ( + BaseModel, + ConfigDict, + Field, + computed_field, + field_validator, + model_validator, +) + +from modelforge.utils.units import _convert_str_or_unit_to_unit_length + + +class CaseInsensitiveEnum(str, Enum): + @classmethod + def _missing_(cls, value): + for member in cls: + if member.value.lower() == value.lower(): + return member + return super()._missing_(value) + + +# To avoid having to set config parameters for each class, +# we will just create a parent class for all the parameters classes. +class ParametersBase(BaseModel): + model_config = ConfigDict( + use_enum_values=True, + arbitrary_types_allowed=True, + validate_assignment=True, + extra="forbid", + ) + + +# for the activation functions we have defined alpha and negative slope are the +# two parameters that are possible +class ActivationFunctionParamsAlpha(BaseModel): + alpha: Optional[float] = None + + +class ActivationFunctionParamsNegativeSlope(BaseModel): + negative_slope: Optional[float] = None + + +class AtomicNumber(BaseModel): + maximum_atomic_number: int = 101 + number_of_per_atom_features: int = 32 + + +class Featurization(BaseModel): + properties_to_featurize: List[str] + atomic_number: AtomicNumber = Field(default_factory=AtomicNumber) + + +class ActivationFunctionName(CaseInsensitiveEnum): + ReLU = "ReLU" + CeLU = "CeLU" + GeLU = "GeLU" + Sigmoid = "Sigmoid" + Softmax = "Softmax" + ShiftedSoftplus = "ShiftedSoftplus" + SiLU = "SiLU" + Tanh = "Tanh" + LeakyReLU = "LeakyReLU" + ELU = "ELU" + + +# this enum will tell us if we need to pass additional parameters to the activation function +class ActivationFunctionParamsEnum(CaseInsensitiveEnum): + ReLU = "None" + CeLU = ActivationFunctionParamsAlpha + GeLU = "None" + Sigmoid = "None" + Softmax = "None" + ShiftedSoftplus = "None" + SiLU = "None" + Tanh = "None" + LeakyReLU = ActivationFunctionParamsNegativeSlope + ELU = ActivationFunctionParamsAlpha + + +class CoreParameterBase(ParametersBase): + # Ensure that both lists (properties and sizes) have the same length + @model_validator(mode="after") + def validate_predicted_properties(self): + if len(self.predicted_properties) != len(self.predicted_dim): + raise ValueError( + "The length of 'predicted_properties' and 'predicted_dim' must be the same." + ) + return self + + +class ActivationFunctionConfig(ParametersBase): + + activation_function_name: ActivationFunctionName + activation_function_arguments: Optional[ + Union[ActivationFunctionParamsAlpha, ActivationFunctionParamsNegativeSlope] + ] = None + + @model_validator(mode="after") + def validate_activation_function_arguments(self) -> "ActivationFunctionConfig": + if ActivationFunctionParamsEnum[self.activation_function_name].value != "None": + if self.activation_function_arguments is None: + raise ValueError( + f"Activation function {self.activation_function_name} requires additional arguments." + ) + else: + if self.activation_function_arguments is not None: + raise ValueError( + f"Activation function {self.activation_function_name} does not require additional arguments." + ) + return self + + def return_activation_function(self): + from modelforge.potential.utils import ACTIVATION_FUNCTIONS + + if self.activation_function_arguments is not None: + return ACTIVATION_FUNCTIONS[self.activation_function_name]( + **self.activation_function_arguments.model_dump(exclude_unset=True) + ) + return ACTIVATION_FUNCTIONS[self.activation_function_name]() + + @computed_field + @property + def activation_function(self) -> Type[torch.nn.Module]: + return self.return_activation_function() + + +# these will all be set by default to false such that we do not need to define +# unused post processing operations in the datafile +class GeneralPostProcessingOperation(ParametersBase): + calculate_molecular_self_energy: bool = False + calculate_atomic_self_energy: bool = False + + +class PerAtomEnergy(ParametersBase): + normalize: bool = False + from_atom_to_system_reduction: bool = False + keep_per_atom_property: bool = False + + +class PerAtomCharge(ParametersBase): + conserve: bool = True + conserve_strategy: str = "default" + + +class ElectrostaticPotential(ParametersBase): + electrostatic_strategy: str = "coulomb" + maximum_interaction_radius: float = 0.5 + + converted_units = field_validator( + "maximum_interaction_radius", + mode="before", + )(_convert_str_or_unit_to_unit_length) + + +class PostProcessingParameter(ParametersBase): + properties_to_process: List[str] + per_atom_energy: PerAtomEnergy = PerAtomEnergy() + per_atom_charge: PerAtomCharge = PerAtomCharge() + electrostatic_potential: ElectrostaticPotential = ElectrostaticPotential() + general_postprocessing_operation: GeneralPostProcessingOperation = ( + GeneralPostProcessingOperation() + ) + + +class AimNet2Parameters(ParametersBase): + class CoreParameter(CoreParameterBase): + number_of_radial_basis_functions: int + maximum_interaction_radius: float + number_of_interaction_modules: int + activation_function_parameter: ActivationFunctionConfig + featurization: Featurization + predicted_properties: List[str] + predicted_dim: List[int] + number_of_vector_features: int + converted_units = field_validator("maximum_interaction_radius", mode="before")( + _convert_str_or_unit_to_unit_length + ) + + potential_name: str = "AimNet2" + only_unique_pairs: bool = False + core_parameter: CoreParameter + postprocessing_parameter: PostProcessingParameter + potential_seed: Optional[int] = None + + +class ANI2xParameters(ParametersBase): + class CoreParameter(CoreParameterBase): + angle_sections: int + maximum_interaction_radius: float + minimum_interaction_radius: float + number_of_radial_basis_functions: int + maximum_interaction_radius_for_angular_features: float + minimum_interaction_radius_for_angular_features: float + angular_dist_divisions: int + activation_function_parameter: ActivationFunctionConfig + predicted_properties: List[str] + predicted_dim: List[int] + + converted_units = field_validator( + "maximum_interaction_radius", + "minimum_interaction_radius", + "maximum_interaction_radius_for_angular_features", + "minimum_interaction_radius_for_angular_features", + mode="before", + )(_convert_str_or_unit_to_unit_length) + + potential_name: str = "ANI2x" + only_unique_pairs: bool = True + core_parameter: CoreParameter + postprocessing_parameter: PostProcessingParameter + potential_seed: Optional[int] = None + + +class SchNetParameters(ParametersBase): + class CoreParameter(CoreParameterBase): + number_of_radial_basis_functions: int + maximum_interaction_radius: float + number_of_interaction_modules: int + number_of_filters: int + shared_interactions: bool + activation_function_parameter: ActivationFunctionConfig + featurization: Featurization + predicted_properties: List[str] + predicted_dim: List[int] + + converted_units = field_validator("maximum_interaction_radius", mode="before")( + _convert_str_or_unit_to_unit_length + ) + + potential_name: str = "SchNet" + only_unique_pairs: bool = False + core_parameter: CoreParameter + postprocessing_parameter: PostProcessingParameter + potential_seed: int = -1 + + +class TensorNetParameters(ParametersBase): + class CoreParameter(CoreParameterBase): + number_of_per_atom_features: int + number_of_interaction_layers: int + number_of_radial_basis_functions: int + maximum_interaction_radius: float + minimum_interaction_radius: float + maximum_atomic_number: int + equivariance_invariance_group: str + activation_function_parameter: ActivationFunctionConfig + predicted_properties: List[str] + predicted_dim: List[int] + + converted_units = field_validator( + "maximum_interaction_radius", "minimum_interaction_radius", mode="before" + )(_convert_str_or_unit_to_unit_length) + + potential_name: str = "TensorNet" + only_unique_pairs: bool = False + core_parameter: CoreParameter + postprocessing_parameter: PostProcessingParameter + potential_seed: Optional[int] = None + + +class PaiNNParameters(ParametersBase): + class CoreParameter(CoreParameterBase): + + number_of_radial_basis_functions: int + maximum_interaction_radius: float + number_of_interaction_modules: int + shared_interactions: bool + shared_filters: bool + featurization: Featurization + activation_function_parameter: ActivationFunctionConfig + predicted_properties: List[str] + predicted_dim: List[int] + + converted_units = field_validator("maximum_interaction_radius", mode="before")( + _convert_str_or_unit_to_unit_length + ) + + potential_name: str = "PaiNN" + only_unique_pairs: bool = False + core_parameter: CoreParameter + postprocessing_parameter: PostProcessingParameter + potential_seed: Optional[int] = None + + +class PhysNetParameters(ParametersBase): + class CoreParameter(CoreParameterBase): + + number_of_radial_basis_functions: int + maximum_interaction_radius: float + number_of_interaction_residual: int + number_of_modules: int + featurization: Featurization + activation_function_parameter: ActivationFunctionConfig + predicted_properties: List[str] + predicted_dim: List[int] + + converted_units = field_validator("maximum_interaction_radius", mode="before")( + _convert_str_or_unit_to_unit_length + ) + + potential_name: str = "PhysNet" + only_unique_pairs: bool = False + core_parameter: CoreParameter + postprocessing_parameter: PostProcessingParameter + potential_seed: Optional[int] = None + + +class SAKEParameters(ParametersBase): + class CoreParameter(CoreParameterBase): + + number_of_radial_basis_functions: int + maximum_interaction_radius: float + number_of_interaction_modules: int + number_of_spatial_attention_heads: int + featurization: Featurization + activation_function_parameter: ActivationFunctionConfig + predicted_properties: List[str] + predicted_dim: List[int] + + converted_units = field_validator("maximum_interaction_radius", mode="before")( + _convert_str_or_unit_to_unit_length + ) + + potential_name: str = "SAKE" + only_unique_pairs: bool = False + core_parameter: CoreParameter + postprocessing_parameter: PostProcessingParameter + potential_seed: Optional[int] = None diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index 5427f938..770623b7 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -1,638 +1,614 @@ -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Optional, Union +""" +Implementation of the PhysNet neural network potential. +""" + +from typing import Dict import torch from loguru import logger as log -from openff.units import unit from torch import nn -from .models import InputPreparation, NNPInput, BaseNetwork, CoreNetwork - -from modelforge.potential.utils import NeuralNetworkData - -if TYPE_CHECKING: - from modelforge.dataset.dataset import NNPInput - - from .models import PairListOutputs - - -@dataclass -class PhysNetNeuralNetworkData(NeuralNetworkData): - """ - A dataclass to structure the inputs for PhysNet-based neural network potentials, - facilitating the efficient and structured representation of atomic systems for - energy computation and property prediction within the PhysNet framework. - - Attributes - ---------- - f_ij : Optional[torch.Tensor] - A tensor representing the radial basis function (RBF) expansion applied to distances between atom pairs, - capturing the local chemical environment. Will be added after initialization. Shape: [num_pairs, num_rbf]. - number_of_atoms : int - An integer indicating the number of atoms in the batch. - atomic_embedding : torch.Tensor - A 2D tensor containing embeddings or features for each atom, derived from atomic numbers or other properties. - Shape: [num_atoms, embedding_dim]. - - Notes - ----- - The `PhysNetNeuralNetworkInput` class encapsulates essential geometric and chemical information required by - the PhysNet model to predict system energies and properties. It includes information on atomic positions, types, - and connectivity, alongside derived features such as radial basis functions (RBF) for detailed representation - of atomic environments. This structured input format ensures that all relevant data is readily available for - the PhysNet model, supporting its complex network architecture and computation requirements. - - Examples - -------- - >>> physnet_input = PhysNetNeuralNetworkInput( - ... atomic_numbers=torch.tensor([1, 6, 6, 8]), - ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]), - ... atomic_subsystem_indices=torch.tensor([0, 0, 0, 0]), - ... total_charge=torch.tensor([0.0]), - ... pair_indices=torch.tensor([[0, 1], [0, 2], [1, 2]]), - ... d_ij=torch.tensor([1.0, 1.0, 1.0]), - ... r_ij=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), - ... f_ij=torch.randn(3, 4), # Radial basis function expansion - ... number_of_atoms=torch.tensor([4]), - ... atomic_embedding=torch.randn(4, 5) # Example atomic embeddings/features - ... ) - """ - - atomic_embedding: torch.Tensor - f_ij: Optional[torch.Tensor] = field(default=None) + +from modelforge.utils.prop import NNPInput +from modelforge.potential.neighbors import PairlistData +from .utils import Dense class PhysNetRepresentation(nn.Module): def __init__( self, - cutoff: unit = 5 * unit.angstrom, - number_of_radial_basis_functions: int = 16, + maximum_interaction_radius: float, + number_of_radial_basis_functions: int, + featurization_config: Dict[str, Dict[str, int]], ): """ - Representation module for the PhysNet potential, handling the generation of - the radial basis functions (RBFs) with a cutoff. + Representation module for PhysNet, generating radial basis functions + (RBFs) and atomic embeddings with a cutoff for atomic interactions. Parameters ---------- - cutoff : openff.units.unit.Quantity, default=5*unit.angstrom + maximum_interaction_radius : float The cutoff distance for interactions. - number_of_radial_basis_functions : int, default=16 + number_of_radial_basis_functions : int Number of radial basis functions to use. + featurization_config : Dict[str, Dict[str, int]] + Configuration for atomic feature generation. """ super().__init__() - # cutoff - from modelforge.potential import CosineCutoff - - self.cutoff_module = CosineCutoff(cutoff) + # Initialize the cutoff function and radial basis function modules + from modelforge.potential import ( + CosineAttenuationFunction, + PhysNetRadialBasisFunction, + FeaturizeInput, + ) - # radial symmetry function - from .utils import PhysNetRadialBasisFunction + self.cutoff_module = CosineAttenuationFunction(maximum_interaction_radius) + self.featurize_input = FeaturizeInput(featurization_config) + # Radial symmetry function using PhysNet radial basis expansion self.radial_symmetry_function_module = PhysNetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, - max_distance=cutoff, + max_distance=maximum_interaction_radius, dtype=torch.float32, ) - def forward(self, d_ij: torch.Tensor) -> torch.Tensor: + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: """ - Forward pass of the representation module. + Forward pass for the representation module, generating RBFs and + atomic embeddings. Parameters ---------- - d_ij : torch.Tensor - pairwise distances between atoms, shape (n_pairs). + data : NNPInput + Input data containing atomic positions, atomic numbers, etc. + pairlist_output : PairlistData + Output from the pairlist module containing distances and pair indices. Returns ------- - torch.Tensor - The radial basis function expansion applied to the input distances, - shape (n_pairs, n_gaussians), after applying the cutoff function. + Dict[str, torch.Tensor] + A dictionary with RBFs and atomic embeddings. """ + # Generate radial basis function expansion and apply cutoff + f_ij = self.radial_symmetry_function_module(pairlist_output.d_ij).squeeze() + f_ij = torch.mul(f_ij, self.cutoff_module(pairlist_output.d_ij)) + + return { + "f_ij": f_ij, + "atomic_embedding": self.featurize_input(data), + } - f_ij = self.radial_symmetry_function_module(d_ij).squeeze() - cutoff = self.cutoff_module(d_ij) - f_ij = torch.mul(f_ij, cutoff) - return f_ij +class PhysNetResidual(nn.Module): -class GatingModule(nn.Module): - def __init__(self, number_of_atom_basis: int): + def __init__( + self, + input_dim: int, + output_dim: int, + activation_function: torch.nn.Module, + ): """ - Initializes a gating module that - optionally applies a sigmoid gating mechanism to input features. + Residual block for PhysNet, refining atomic feature vectors by adding + a residual component. - Parameters: - ----------- + Parameters + ---------- input_dim : int - The dimensionality of the input (and output) features. + Dimensionality of the input feature vector. + output_dim : int + Dimensionality of the output feature vector, which typically matches the + input dimension. + activation_function : Type[torch.nn.Module] + The activation function to be used in the residual block. """ super().__init__() - self.gate = nn.Parameter(torch.ones(number_of_atom_basis)) - def forward(self, x: torch.Tensor, activation_fn: bool = False) -> torch.Tensor: - """ - Apply gating to the input tensor. - - Parameters: - ----------- - x : torch.Tensor - The input tensor to gate. - - Returns: - -------- - torch.Tensor - The gated input tensor. - """ - gating_signal = torch.sigmoid(self.gate) - return gating_signal * x - - -from .utils import ShiftedSoftplus, Dense - - -class PhysNetResidual(nn.Module): - """ - Implements a preactivation residual block as described in Equation 4 of the PhysNet paper. - - The block refines atomic feature vectors by adding a residual component computed through - two linear transformations and a non-linear activation function (Softplus). This setup - enhances gradient flow and supports effective deep network training by employing a - preactivation scheme. - - Parameters: - ----------- - input_dim: int - Dimensionality of the input feature vector. - output_dim: int - Dimensionality of the output feature vector, which typically matches the input dimension. - """ - - def __init__(self, input_dim: int, output_dim: int): - super().__init__() - self.dense = Dense(input_dim, output_dim, activation=ShiftedSoftplus()) - self.residual = Dense(output_dim, output_dim) + # Define the dense layers and residual connection with activation + self.dense = nn.Sequential( + activation_function, + Dense(input_dim, output_dim, activation_function), + Dense(output_dim, output_dim), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Forward pass of the ResidualBlock. + Forward pass of the residual block. - Parameters: - ----------- - x: torch.Tensor - Input tensor containing feature vectors of atoms. + Parameters + ---------- + x : torch.Tensor + Input feature tensor. - Returns: - -------- + Returns + ------- torch.Tensor - Output tensor after applying the residual block operations. + Output tensor after applying residual connection. """ - # update x with residual - return x + self.residual(self.dense(x)) + return x + self.dense(x) class PhysNetInteractionModule(nn.Module): def __init__( self, - number_of_atom_features: int = 64, - number_of_radial_basis_functions: int = 16, - number_of_interaction_residual: int = 3, + number_of_per_atom_features: int, + number_of_radial_basis_functions: int, + number_of_interaction_residual: int, + activation_function: torch.nn.Module, ): """ - Module to compute interaction terms based on atomic distances and features. + Module for computing interaction terms based on atomic distances and features. Parameters ---------- - number_of_atom_features : int, default=64 + number_of_per_atom_features : int Dimensionality of the atomic embeddings. - number_of_radial_basis_functions : int, default=16 - Specifies the number of basis functions for the Gaussian Logarithm Attention, - essentially defining the output feature dimension for attention-weighted interactions. + number_of_radial_basis_functions : int + Number of radial basis functions for the interaction. + number_of_interaction_residual : int + Number of residual blocks in the interaction module. + activation_function : torch.nn.Module + The activation function to be used in the interaction module. """ super().__init__() - from .utils import ShiftedSoftplus, Dense + from .utils import DenseWithCustomDist - self.attention_mask = Dense( + # Initialize activation function + self.activation_function = activation_function + + # Initialize attention mask + self.attention_mask = DenseWithCustomDist( number_of_radial_basis_functions, - number_of_atom_features, + number_of_per_atom_features, bias=False, weight_init=torch.nn.init.zeros_, ) - self.activation_function = ShiftedSoftplus() - # Networks for processing atomic embeddings of i and j atoms + # Initialize networks for processing atomic embeddings of i and j atoms self.interaction_i = Dense( - number_of_atom_features, - number_of_atom_features, - activation=self.activation_function, + number_of_per_atom_features, + number_of_per_atom_features, + activation_function=activation_function, ) self.interaction_j = Dense( - number_of_atom_features, - number_of_atom_features, - activation=self.activation_function, + number_of_per_atom_features, + number_of_per_atom_features, + activation_function=activation_function, ) - self.process_v = Dense(number_of_atom_features, number_of_atom_features) + # Initialize processing network + self.process_v = Dense(number_of_per_atom_features, number_of_per_atom_features) - # Residual block + # Initialize residual blocks self.residuals = nn.ModuleList( [ - PhysNetResidual(number_of_atom_features, number_of_atom_features) + PhysNetResidual( + input_dim=number_of_per_atom_features, + output_dim=number_of_per_atom_features, + activation_function=activation_function, + ) for _ in range(number_of_interaction_residual) ] ) - # Gating - self.gate = nn.Parameter(torch.ones(number_of_atom_features)) + # Gating and dropout layers + self.gate = nn.Parameter(torch.ones(number_of_per_atom_features)) self.dropout = nn.Dropout(p=0.05) - def forward(self, data: PhysNetNeuralNetworkData) -> torch.Tensor: + def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: """ - Processes input tensors through the interaction module, applying - Gaussian Logarithm Attention to modulate the influence of pairwise distances - on the interaction features, followed by aggregation to update atomic embeddings. + Forward pass for the interaction module. Parameters ---------- - inputs : PhysNetNeuralNetworkInput + data : Dict[str, torch.Tensor] + Input data including pairwise distances, pair indices, and atomic + embeddings. Returns ------- torch.Tensor - Updated atomic feature representations incorporating interaction information. + Updated atomic embeddings after interaction computation. """ - # Equation 6: Formation of the Proto-Message ṽ_i for an Atom i - # ṽ_i = σ(Wl_I * x_i^l + bl_I) + Σ_j (G_g * Wl * (σ(σl_J * x_j^l + bl_J)) * g(r_ij)) - # Equation 6 implementation overview: - # ṽ_i = x_i_prime + sum_over_j(x_j_prime * f_ij_prime) - # where: - # - x_i_prime and x_j_prime are the features of atoms i and j, respectively, processed through separate networks. - # - f_ij_prime represents the modulated radial basis functions (f_ij) by the Gaussian Logarithm Attention weights. - - # extract relevant variables - idx_i, idx_j = data.pair_indices - f_ij = data.f_ij - x = data.atomic_embedding - - # # Apply activation to atomic embeddings - xa = self.dropout(self.activation_function(x)) - - # calculate attention weights and - # transform to - # input shape: (number_of_pairs, number_of_radial_basis_functions) - # output shape: (number_of_pairs, number_of_atom_features) - g = self.attention_mask(f_ij) - - # Calculate contribution of central atom - x_i = self.interaction_i(xa) - # Calculate contribution of neighbor atom - x_j = self.interaction_j(xa) - # Gather the results according to idx_j - x_j = x_j[idx_j] - # Multiply the gathered features by g - x_j_modulated = x_j * g - # Aggregate modulated contributions for each atom i - x_j_prime = torch.zeros_like(x_i) - x_j_prime.scatter_add_( - 0, idx_i.unsqueeze(-1).expand(-1, x_j_modulated.size(-1)), x_j_modulated + + idx_i, idx_j = data["pair_indices"].unbind() + + # Apply activation to atomic embeddings + # first term in equation 6 in the PhysNet paper + embedding_atom_i = self.activation_function( + self.interaction_i(data["atomic_embedding"]) + ) # shape (nr_of_atoms_in_batch, atomic_embedding_dim) + + # second term in equation 6 in the PhysNet paper + # apply attention mask G to radial basis functions f_ij + g = self.attention_mask( + data["f_ij"] + ) # shape (nr_of_atom_pairs_in_batch, atomic_embedding_dim) + # calculate the updated embedding for atom j + # NOTE: this changes the 2nd dimension from number_of_radial_basis_functions to atomic_embedding_dim + embedding_atom_j = self.activation_function( + self.interaction_j(data["atomic_embedding"])[ + idx_j + ] # NOTE this is the same as the embedding_atom_i, but then we are selecting the embedding of atom j + # shape (nr_of_atom_pairs_in_batch, atomic_embedding_dim) + ) + updated_embedding_atom_j = torch.mul( + g, embedding_atom_j + ) # element-wise multiplication + + # Sum over contributions from atom j as function of embedding of atom i + # and attention mask G(f_ij) + embedding_atom_i.scatter_add_( + 0, + idx_i.unsqueeze(-1).expand(-1, updated_embedding_atom_j.shape[-1]), + updated_embedding_atom_j, ) - # Draft proto message v_tilde - m = x_i + x_j_prime - # shape of m (nr_of_atoms_in_batch, 1) - # Equation 4: Preactivation Residual Block Implementation - # xl+2_i = xl_i + Wl+1 * sigma(Wl * xl_i + bl) + bl+1 + # apply residual blocks for residual in self.residuals: - m = residual( - m + embedding_atom_i = residual( + embedding_atom_i ) # shape (nr_of_atoms_in_batch, number_of_radial_basis_functions) - m = self.activation_function(m) - x = self.gate * x + self.process_v(m) - return x + + # Apply dropout to the embedding after the residuals + embedding_atom_i = self.dropout(embedding_atom_i) + + # eqn 5 in the PhysNet paper + embedding_atom_i = self.gate * data["atomic_embedding"] + self.process_v( + self.activation_function(embedding_atom_i) + ) + return embedding_atom_i class PhysNetOutput(nn.Module): def __init__( self, - number_of_atom_features: int, - number_of_atomic_properties: int = 2, - number_of_residuals_in_output: int = 2, + number_of_per_atom_features: int, + number_of_atomic_properties: int, + number_of_residuals_in_output: int, + activation_function: torch.nn.Module, ): - from .utils import Dense + """ + Output module for the PhysNet model, responsible for generating predictions + from atomic embeddings. + + Parameters + ---------- + number_of_per_atom_features : int + Dimensionality of the atomic embeddings. + number_of_atomic_properties : int + Number of atomic properties to predict. + number_of_residuals_in_output : int + Number of residual blocks in the output module. + activation_function : torch.nn.Module + Activation function to apply in the output module. + """ + from .utils import DenseWithCustomDist super().__init__() + # Initialize residual blocks self.residuals = nn.Sequential( *[ - PhysNetResidual(number_of_atom_features, number_of_atom_features) + PhysNetResidual( + number_of_per_atom_features, + number_of_per_atom_features, + activation_function, + ) for _ in range(number_of_residuals_in_output) ] ) - self.output = Dense( - number_of_atom_features, + # Output layer for predicting atomic properties + self.output = DenseWithCustomDist( + number_of_per_atom_features, number_of_atomic_properties, - weight_init=torch.nn.init.zeros_, - bias=False, + weight_init=torch.nn.init.zeros_, # NOTE: the result of this initialization is that before the first parameter update the output is zero ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the output module. + + Parameters + ---------- + x : torch.Tensor + Input tensor containing atomic feature vectors. + + Returns + ------- + torch.Tensor + Predicted atomic properties. + """ x = self.output(self.residuals(x)) return x class PhysNetModule(nn.Module): + def __init__( self, - number_of_atom_features: int = 64, - number_of_radial_basis_functions: int = 16, - number_of_interaction_residual: int = 2, + number_of_per_atom_features: int, + number_of_radial_basis_functions: int, + number_of_interaction_residual: int, + activation_function: torch.nn.Module, + number_of_residuals_in_output: int, + number_of_atomic_properties: int, ): """ - Wrapper module that combines the PhysNetInteraction, PhysNetResidual, and - PhysNetOutput classes into a single module. This serves as the building - block for the PhysNet model. + Wrapper for the PhysNet interaction and output modules. - This is a skeletal implementation that needs to be expanded upon. + Parameters + ---------- + number_of_per_atom_features : int + Dimensionality of the atomic embeddings. + number_of_radial_basis_functions : int + Number of radial basis functions. + number_of_interaction_residual : int + Number of residual blocks in the interaction module. + activation_function : torch.nn.Module + Activation function to apply in the modules. + number_of_residuals_in_output : int + Number of residual blocks in the output module. + number_of_atomic_properties : int + Number of atomic properties to predict. """ super().__init__() - # this class combines the PhysNetInteraction, PhysNetResidual and - # PhysNetOutput class - + # Initialize interaction module self.interaction = PhysNetInteractionModule( - number_of_atom_features=number_of_atom_features, + number_of_per_atom_features=number_of_per_atom_features, number_of_radial_basis_functions=number_of_radial_basis_functions, number_of_interaction_residual=number_of_interaction_residual, + activation_function=activation_function, ) + # Initialize output module self.output = PhysNetOutput( - number_of_atom_features=number_of_atom_features, - number_of_atomic_properties=2, + number_of_per_atom_features=number_of_per_atom_features, + number_of_atomic_properties=number_of_atomic_properties, + number_of_residuals_in_output=number_of_residuals_in_output, + activation_function=activation_function, ) - def forward(self, data: PhysNetNeuralNetworkData) -> Dict[str, torch.Tensor]: + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Forward pass for the PhysNet module. + + Parameters + ---------- + data : Dict[str, torch.Tensor] + Input data containing atomic features and pairwise information. + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing predictions and updated embeddings. """ - # The PhysNet module is a sequence of interaction modules and residual modules. - # x_1, ..., x_N - # | - # v - # ┌─────────────┐ - # │ interaction │ <-- g(d_ij) - # └─────────────┘ - # │ - # v - # ┌───────────┐ - # │ residual │ - # └───────────┘ - # ┌───────────┐ - # │ residual │ - # └───────────┘ - # ┌───────────┐ │ - # │ output │<-----│ - # └───────────┘ │ - # v - - # calculate the interaction - v = self.interaction(data) - - # calculate the module output - prediction = self.output(v) + # Update embeddings via interaction + updated_embedding = self.interaction(data) + + # Generate atomic property predictions + prediction = self.output(updated_embedding) return { "prediction": prediction, - "updated_embedding": v, # input for next module + "updated_embedding": updated_embedding, # input for next module } -class PhysNetCore(CoreNetwork): +from typing import List + + +class PhysNetCore(torch.nn.Module): + def __init__( self, - max_Z: int, - cutoff: unit.Quantity, - number_of_atom_features: int, + featurization: Dict[str, Dict[str, int]], + maximum_interaction_radius: float, number_of_radial_basis_functions: int, number_of_interaction_residual: int, number_of_modules: int, + activation_function_parameter: Dict[str, str], + predicted_properties: List[str], + predicted_dim: List[int], + potential_seed: int = -1, ) -> None: """ - Implementation of the PhysNet neural network potential. + Core implementation of PhysNet, combining multiple PhysNet modules. Parameters ---------- - max_Z : int, default=100 - Maximum atomic number to be embedded. - number_of_atom_features : int, default=64 - Dimension of the embedding vectors for atomic numbers. - cutoff : openff.units.unit.Quantity, default=5*unit.angstrom - The cutoff distance for interactions. - number_of_modules : int, default=2( + featurization : Dict[str, Dict[str, int]] + Configuration for atomic feature generation. + maximum_interaction_radius : float + Cutoff distance for atomic interactions. + number_of_radial_basis_functions : int + Number of radial basis functions for interaction computation. + number_of_interaction_residual : int + Number of residual blocks in the interaction modules. + number_of_modules : int + Number of PhysNet modules to stack. + activation_function_parameter : Dict[str, str] + Configuration for the activation function. + predicted_properties : List[str] + List of properties to predict. + predicted_dim : List[int] + List of dimensions corresponding to the predicted properties. + potential_seed : int, optional + Seed for random number generation, by default -1. """ + from modelforge.utils.misc import seed_random_number - log.debug("Initializing PhysNet model.") - super().__init__() + if potential_seed != -1: + seed_random_number(potential_seed) - # embedding - from modelforge.potential.utils import Embedding + super().__init__() + self.activation_function = activation_function_parameter["activation_function"] - self.embedding_module = Embedding(max_Z, number_of_atom_features) + log.debug("Initializing the PhysNet architecture.") + # Initialize atomic feature dimensions and representation module + number_of_per_atom_features = int( + featurization["atomic_number"]["number_of_per_atom_features"] + ) self.physnet_representation_module = PhysNetRepresentation( - cutoff=cutoff, + maximum_interaction_radius=maximum_interaction_radius, number_of_radial_basis_functions=number_of_radial_basis_functions, + featurization_config=featurization, ) # initialize the PhysNetModule building blocks from torch.nn import ModuleList + self.output_dim = int(sum(predicted_dim)) + # Stack multiple PhysNet modules self.physnet_module = ModuleList( [ PhysNetModule( - number_of_atom_features, + number_of_per_atom_features, number_of_radial_basis_functions, number_of_interaction_residual, + number_of_residuals_in_output=2, + number_of_atomic_properties=self.output_dim, + activation_function=self.activation_function, ) for _ in range(number_of_modules) ] ) - self.atomic_scale = nn.Parameter(torch.ones(max_Z, 2)) - self.atomic_shift = nn.Parameter(torch.zeros(max_Z, 2)) - - def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" - ) -> PhysNetNeuralNetworkData: - # Perform atomic embedding - atomic_embedding = self.embedding_module(data.atomic_numbers) - # Z_i, ..., Z_N - # - # │ - # ∨ - # ┌────────────┐ - # │ embedding │ - # └────────────┘ - - number_of_atoms = data.atomic_numbers.shape[0] - - nnp_input = PhysNetNeuralNetworkData( - pair_indices=pairlist_output.pair_indices, - d_ij=pairlist_output.d_ij, - r_ij=pairlist_output.r_ij, - f_ij=None, - number_of_atoms=number_of_atoms, - positions=data.positions, - atomic_numbers=data.atomic_numbers, - atomic_subsystem_indices=data.atomic_subsystem_indices, - total_charge=data.total_charge, - atomic_embedding=atomic_embedding, # atom embedding + # Define learnable atomic shift and scale per atomic property + maximum_atomic_number = int( + featurization["atomic_number"]["maximum_atomic_number"] + ) + + self.atomic_scale = nn.Parameter( + torch.ones( + maximum_atomic_number, + len(predicted_properties), + ) + ) + self.atomic_shift = nn.Parameter( + torch.zeros( + maximum_atomic_number, + len(predicted_properties), + ) ) - return nnp_input + self.predicted_properties = predicted_properties + self.predicted_dim = predicted_dim def compute_properties( - self, data: PhysNetNeuralNetworkData + self, + data: NNPInput, + pairlist_output: PairlistData, ) -> Dict[str, torch.Tensor]: """ - Calculate the energy for a given input batch. + Compute properties for a given input batch. + Parameters ---------- - inputs : PhysNetNeutralNetworkInput + data : NNPInput + Input data containing atomic features and pairwise information. + pairlist_output : PairlistData + Output from the pairlist module. Returns ------- - torch.Tensor - Calculated energies; shape (nr_systems,). + Dict[str, torch.Tensor] + Calculated atomic properties. """ - # Computed representation - data.f_ij = self.physnet_representation_module(data.d_ij).squeeze( - 1 - ) # shape: (n_pairs, number_of_radial_basis_functions) - nr_of_atoms_in_batch = data.number_of_atoms - - # d_i, ..., d_N - # - # │ - # V - # ┌────────────┐ - # │ RBF │ - # └────────────┘ - - # see https://doi.org/10.1021/acs.jctc.9b00181 - # in the following we are implementing the calculations analoguous - # to the modules outlined in Figure 1 - - # NOTE: both embedding and f_ij (the output of the Radial Symmetry Function) are - # stored in `inputs` - # inputs are the embedding vectors and f_ij - # the embedding vector will get updated in each pass through the modules - - # ┌────────────┐ ┌────────────┐ - # │ embedding │ │ RBF │ - # └────────────┘ └────────────┘ - # | │ - # ┌───────────────┐ │ - # | <-- | module 1 │ <--│ - # | └────────────---┘ │ - # | | │ - # E_1, ..., E_N (+) V │ - # | ┌───────────────┐ │ - # | <-- | module 2 │ <--│ - # └────────────---┘ - - # the atomic energies are accumulated in per_atom_energies - prediction_i = torch.zeros( - (nr_of_atoms_in_batch, 2), - device=data.d_ij.device, + # Compute representations for the input data + representation = self.physnet_representation_module(data, pairlist_output) + + # Initialize tensor to store accumulated property predictions + nr_of_atoms_in_batch = data.atomic_numbers.shape[0] + per_atom_property_prediction = torch.zeros( + (nr_of_atoms_in_batch, self.output_dim), + device=data.atomic_numbers.device, ) + # Pass through stacked PhysNet modules + module_data: Dict[str, torch.Tensor] = { + "pair_indices": pairlist_output.pair_indices, + "f_ij": representation["f_ij"], + "atomic_embedding": representation["atomic_embedding"], + } + for module in self.physnet_module: - output_of_module = module(data) - # accumulate output for atomic energies - prediction_i += output_of_module["prediction"] + module_output = module(module_data) + # accumulate output for atomic properties + per_atom_property_prediction = ( + per_atom_property_prediction + module_output["prediction"] + ) # update embedding for next module - data.atomic_embedding = output_of_module["updated_embedding"] - - prediction_i_shifted_scaled = ( - self.atomic_shift[data.atomic_numbers] - + prediction_i * self.atomic_scale[data.atomic_numbers] - ) + module_data["atomic_embedding"] = module_output["updated_embedding"] - # sum over atom features - E_i = prediction_i_shifted_scaled[:, 0] # shape(nr_of_atoms, 1) - q_i = prediction_i_shifted_scaled[:, 1] # shape(nr_of_atoms, 1) - - output = { - "per_atom_energy": E_i.contiguous(), # reshape memory mapping for JAX/dlpack - "q_i": q_i.contiguous(), + # Return computed properties and representations + return { + "per_atom_scalar_representation": module_output["updated_embedding"], + "per_atom_prediction": per_atom_property_prediction, "atomic_subsystem_indices": data.atomic_subsystem_indices, "atomic_numbers": data.atomic_numbers, } - return output - - -from .models import InputPreparation, NNPInput, BaseNetwork -from typing import List - - -class PhysNet(BaseNetwork): - def __init__( - self, - max_Z: int, - cutoff: Union[unit.Quantity, str], - number_of_atom_features: int, - number_of_radial_basis_functions: int, - number_of_interaction_residual: int, - number_of_modules: int, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic: Optional[Dict[str, float]] = None, - ) -> None: + def _aggregate_results( + self, outputs: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: """ - Unke, O. T. and Meuwly, M. "PhysNet: A Neural Network for Predicting Energies, - Forces, Dipole Moments and Partial Charges" arxiv:1902.08408 (2019). + Aggregate atomic property predictions into the final results. + Parameters + ---------- + per_atom_property_prediction : torch.Tensor + Tensor of predicted per-atom properties. + data : NNPInput + Input data containing atomic numbers, etc. + Returns + ------- + Dict[str, torch.Tensor] + Aggregated results containing per-atom predictions and other properties. """ - from modelforge.utils.units import _convert - self.only_unique_pairs = False # NOTE: for pairlist - super().__init__( - dataset_statistic=dataset_statistic, - postprocessing_parameter=postprocessing_parameter, - cutoff=_convert(cutoff), - ) - - self.core_module = PhysNetCore( - max_Z=max_Z, - cutoff=_convert(cutoff), - number_of_atom_features=number_of_atom_features, - number_of_radial_basis_functions=number_of_radial_basis_functions, - number_of_interaction_residual=number_of_interaction_residual, - number_of_modules=number_of_modules, + per_atom_prediction = outputs.pop("per_atom_prediction") + # Apply atomic-specific scaling and shifting to the predicted properties + atomic_numbers = outputs["atomic_numbers"] + per_atom_prediction = ( + self.atomic_shift[atomic_numbers] + + per_atom_prediction * self.atomic_scale[atomic_numbers] + ) # NOTE: Questions: is this appropriate for partial charges? + + # Split predictions for each property + split_tensors = torch.split(per_atom_prediction, self.predicted_dim, dim=1) + outputs.update( + { + label: tensor + for label, tensor in zip(self.predicted_properties, split_tensors) + } ) + return outputs - def _config_prior(self): - log.info("Configuring SchNet model hyperparameter prior distribution") - from modelforge.utils.io import import_ - - tune = import_("ray").tune - # from ray import tune - - from modelforge.potential.utils import shared_config_prior + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: + """ + Forward pass through the entire PhysNet architecture. - prior = { - "number_of_atom_features": tune.randint(2, 256), - "number_of_modules": tune.randint(2, 8), - "number_of_interaction_residual": tune.randint(2, 5), - "cutoff": tune.uniform(5, 10), - "number_of_radial_basis_functions": tune.randint(8, 32), - } - prior.update(shared_config_prior()) - return prior + Parameters + ---------- + data : NNPInput + Input data containing atomic features and pairwise information. + pairlist_output : PairlistData + Pairwise information from the pairlist module. - def combine_per_atom_properties( - self, values: Dict[str, torch.Tensor] - ) -> torch.Tensor: - return values + Returns + ------- + Dict[str, torch.Tensor] + Dictionary with the predicted atomic properties. + """ + # perform the forward pass implemented in the subclass + outputs = self.compute_properties(data, pairlist_output) + # Aggregate and return the results + return self._aggregate_results(outputs) diff --git a/modelforge/potential/potential.py b/modelforge/potential/potential.py new file mode 100644 index 00000000..01b7f35f --- /dev/null +++ b/modelforge/potential/potential.py @@ -0,0 +1,824 @@ +""" +This module contains the base classes for the neural network potentials. +""" + +from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Tuple, TypeVar, Union + +import lightning as pl +import torch +from loguru import logger as log +from openff.units import unit +from modelforge.potential.neighbors import PairlistData + +from modelforge.dataset.dataset import DatasetParameters +from modelforge.utils.prop import NNPInput +from modelforge.potential.parameters import ( + AimNet2Parameters, + ANI2xParameters, + PaiNNParameters, + PhysNetParameters, + SAKEParameters, + SchNetParameters, + TensorNetParameters, +) +from modelforge.train.parameters import RuntimeParameters, TrainingParameters + +# Define a TypeVar that can be one of the parameter models +T_NNP_Parameters = TypeVar( + "T_NNP_Parameters", + ANI2xParameters, + SAKEParameters, + SchNetParameters, + PhysNetParameters, + PaiNNParameters, + TensorNetParameters, + AimNet2Parameters, +) + + +from typing import Callable, Literal, Optional, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from modelforge.train.training import PotentialTrainer + +import numpy as np + + +class JAXModel: + """ + A wrapper for calling a JAX function with predefined parameters and buffers. + + Attributes + ---------- + jax_fn : Callable + The JAX function to be called. + parameter : np.ndarray + Parameters required by the JAX function. + buffer : Any + Buffers required by the JAX function. + name : str + Name of the model. + """ + + def __init__( + self, jax_fn: Callable, parameter: np.ndarray, buffer: np.ndarray, name: str + ): + self.jax_fn = jax_fn + self.parameter = parameter + self.buffer = buffer + self.name = name + + def __call__(self, data: NamedTuple): + """Calls the JAX function using the stored parameters and buffers along with additional data. + + Parameters + ---------- + data : NamedTuple + Data to be passed to the JAX function. + + Returns + ------- + Any + The result of the JAX function. + """ + + return self.jax_fn(self.parameter, self.buffer, data) + + def __repr__(self): + return f"{self.__class__.__name__} wrapping {self.name}" + + +from torch.nn import ModuleDict + +from modelforge.potential.processing import ( + CoulombPotential, + PerAtomCharge, + PerAtomEnergy, +) + + +class PostProcessing(torch.nn.Module): + + _SUPPORTED_PROPERTIES = [ + "per_atom_energy", + "per_atom_charge", + "electrostatic_potential", + "general_postprocessing_operation", + ] + + def __init__( + self, + postprocessing_parameter: Dict[str, Dict[str, bool]], + dataset_statistic: Dict[str, Dict[str, float]], + ): + """ + Handle post-processing operations on model outputs, such as + normalization and reduction. + + Parameters + ---------- + postprocessing_parameter : Dict[str, Dict[str, bool]] + A dictionary containing the postprocessing parameters for each + property. + dataset_statistic : Dict[str, Dict[str, float]] + A dictionary containing the dataset statistics for normalization and + other calculations. + """ + super().__init__() + + self._registered_properties: List[str] = [] + self.registered_chained_operations = ModuleDict() + self.dataset_statistic = dataset_statistic + properties_to_process = postprocessing_parameter["properties_to_process"] + + if "per_atom_energy" in properties_to_process: + self.registered_chained_operations["per_atom_energy"] = PerAtomEnergy( + postprocessing_parameter["per_atom_energy"], + dataset_statistic["training_dataset_statistics"], + ) + self._registered_properties.append("per_atom_energy") + assert all( + prop in PostProcessing._SUPPORTED_PROPERTIES + for prop in self._registered_properties + ) + + if "per_atom_charge" in properties_to_process: + self.registered_chained_operations["per_atom_charge"] = PerAtomCharge( + postprocessing_parameter["per_atom_charge"] + ) + self._registered_properties.append("per_atom_charge") + assert all( + prop in PostProcessing._SUPPORTED_PROPERTIES + for prop in self._registered_properties + ) + + if "electrostatic_potential" in properties_to_process: + if ( + postprocessing_parameter["electrostatic_potential"][ + "electrostatic_strategy" + ] + == "coulomb" + ): + + self.registered_chained_operations["electrostatic_potential"] = ( + CoulombPotential( + postprocessing_parameter["electrostatic_potential"][ + "maximum_interaction_radius" + ], + ) + ) + self._registered_properties.append("electrostatic_potential") + assert all( + prop in PostProcessing._SUPPORTED_PROPERTIES + for prop in self._registered_properties + ) + else: + raise NotImplementedError( + "Only Coulomb potential is supported for electrostatics." + ) + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Perform post-processing for all registered properties. + + Parameters + ---------- + data : Dict[str, torch.Tensor] + The model output data to be post-processed. + + Returns + ------- + Dict[str, torch.Tensor] + The post-processed data. + """ + processed_data: Dict[str, torch.Tensor] = {} + # Iterate over items in ModuleDict + for name, module in self.registered_chained_operations.items(): + + module_output = module.forward(data) + processed_data.update(module_output) + + return processed_data + + +class Potential(torch.nn.Module): + def __init__( + self, + core_network, + neighborlist, + postprocessing, + jit: bool = False, + jit_neighborlist: bool = True, + ): + """ + Neural network potential model composed of a core network, neighborlist, + and post-processing. + + Parameters + ---------- + core_network : torch.nn.Module + The core neural network used for potential energy calculation. + neighborlist : torch.nn.Module + Module for computing neighbor lists and pairwise distances. + postprocessing : torch.nn.Module + Module for handling post-processing operations. + jit : bool, optional + Whether to JIT compile the core network and post-processing + (default: False). + jit_neighborlist : bool, optional + Whether to JIT compile the neighborlist (default: True). + """ + + super().__init__() + + self.core_network = torch.jit.script(core_network) if jit else core_network + self.neighborlist = ( + torch.jit.script(neighborlist) if jit_neighborlist else neighborlist + ) + self.postprocessing = ( + torch.jit.script(postprocessing) if jit else postprocessing + ) + + def _add_total_charge( + self, core_output: Dict[str, torch.Tensor], input_data: NNPInput + ): + """ + Add the total charge to the core output. + + Parameters + ---------- + core_output : Dict[str, torch.Tensor] + The core network output. + input_data : NNPInput + The input data containing the atomic numbers and charges. + + Returns + ------- + Dict[str, torch.Tensor] + The core network output with the total charge added. + """ + # Add the total charge to the core output + core_output["per_system_total_charge"] = input_data.per_system_total_charge + return core_output + + def _add_pairlist( + self, core_output: Dict[str, torch.Tensor], pairlist_output: PairlistData + ): + """ + Add the pairlist to the core output. + + Parameters + ---------- + core_output : Dict[str, torch.Tensor] + The core network output. + pairlist_output : PairlistData + The pairlist output from the neighborlist. + + Returns + ------- + Dict[str, torch.Tensor] + The core network output with the pairlist added. + """ + # Add the pairlist to the core output + core_output["pair_indices"] = pairlist_output.pair_indices + core_output["d_ij"] = pairlist_output.d_ij + core_output["r_ij"] = pairlist_output.r_ij + return core_output + + def _remove_pairlist(self, processed_output: Dict[str, torch.Tensor]): + """ + Remove the pairlist from the core output. + + Parameters + ---------- + processed_output : Dict[str, torch.Tensor] + The postprocessed output. + + Returns + ------- + Dict[str, torch.Tensor] + The postprocessed output with the pairlist removed. + """ + # Remove the pairlist from the core output + del processed_output["pair_indices"] + del processed_output["d_ij"] + del processed_output["r_ij"] + return processed_output + + def forward(self, input_data: NNPInput) -> Dict[str, torch.Tensor]: + """ + Forward pass for the potential model, computing energy and forces. + + Parameters + ---------- + input_data : NNPInput + Input data containing atomic positions and other features. + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing the processed output data. + """ + # Step 1: Compute pair list and distances using Neighborlist + pairlist_output = self.neighborlist.forward(input_data) + + # Step 2: Compute the core network output + core_output = self.core_network.forward(input_data, pairlist_output) + + # Step 3: Apply postprocessing using PostProcessing + core_output = self._add_total_charge(core_output, input_data) + core_output = self._add_pairlist(core_output, pairlist_output) + + processed_output = self.postprocessing.forward(core_output) + processed_output = self._remove_pairlist(processed_output) + return processed_output + + def compute_core_network_output( + self, input_data: NNPInput + ) -> Dict[str, torch.Tensor]: + """ + Compute the core network output, including energy predictions. + + Parameters + ---------- + input_data : NNPInput + Input data containing atomic positions and other features. + + Returns + ------- + Dict[str, torch.Tensor] + Tensor containing the predicted core network output. + """ + # Step 1: Compute pair list and distances using Neighborlist + pairlist_output = self.neighborlist.forward(input_data) + + # Step 2: Compute the core network output + return self.core_network.forward(input_data, pairlist_output) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ): + """ + Load the state dictionary into the infenerence or training model. Note + that the Trainer class encapsulates the Training adapter (the PyTorch + Lightning module), which contains the model. When saving a state dict + from the Trainer class, you need to use `trainer.model.state_dict()` to + save the model state dict. To load this in inference mode, you can use + the `load_state_dict()` function in the Potential class. This function + can load a state dictionary into the model, and removes keys that are + specific to the training mode. + + Parameters + ---------- + state_dict : Mapping[str, Any] + The state dictionary to load. + strict : bool, optional + Whether to strictly enforce that the keys in `state_dict` match the + keys returned by this module's `state_dict()` function (default is + True). + assign : bool, optional + Whether to assign the state dictionary to the model directly + (default is False). + + Notes + ----- + This function can remove a specific prefix from the keys in the state + dictionary. It can also exclude certain keys from being loaded into the + model. + """ + + # Prefix to remove from the keys + prefix = "potential." + # Prefixes of keys to exclude entirely + excluded_prefixes = ["loss."] + + filtered_state_dict = {} + prefixes_removed = set() + + for key, value in state_dict.items(): + # Exclude keys starting with any of the excluded prefixes + if any(key.startswith(ex_prefix) for ex_prefix in excluded_prefixes): + continue # Skip this key entirely + + original_key = key # Keep track of the original key + + # Remove the specified prefix from the key if it exists + if key.startswith(prefix): + key = key[len(prefix) :] + prefixes_removed.add(prefix) + + # change legacy key names + # neighborlist.calculate_distances_and_pairlist.cutoff -> neighborlist.cutoffs + if key == "neighborlist.calculate_distances_and_pairlist.cutoff": + key = "neighborlist.cutoff" + + filtered_state_dict[key] = value + + if prefixes_removed: + log.debug(f"Removed prefixes: {prefixes_removed}") + else: + log.debug("No prefixes found. No modifications to keys in state loading.") + + super().load_state_dict( + filtered_state_dict, + strict=strict, + assign=assign, + ) + + +def setup_potential( + potential_parameter: T_NNP_Parameters, + dataset_statistic: Dict[str, Dict[str, unit.Quantity]] = { + "training_dataset_statistics": { + "per_atom_energy_mean": unit.Quantity(0.0, unit.kilojoule_per_mole), + "per_atom_energy_stddev": unit.Quantity(1.0, unit.kilojoule_per_mole), + } + }, + use_training_mode_neighborlist: bool = False, + potential_seed: Optional[int] = None, + jit: bool = True, + neighborlist_strategy: Optional[str] = None, + verlet_neighborlist_skin: Optional[float] = 0.08, +) -> Potential: + from modelforge.potential import _Implemented_NNPs + from modelforge.potential.utils import remove_units_from_dataset_statistics + from modelforge.utils.misc import seed_random_number + + log.debug(f"potential_seed {potential_seed}") + if potential_seed is not None: + log.info(f"Setting random seed to: {potential_seed}") + seed_random_number(potential_seed) + + model_type = potential_parameter.potential_name + core_network = _Implemented_NNPs.get_neural_network_class(model_type)( + **potential_parameter.core_parameter.model_dump() + ) + + # set unique_pairs based on potential name + only_unique_pairs = potential_parameter.only_unique_pairs + + assert ( + only_unique_pairs is False + if potential_parameter.potential_name.lower() != "ani2x" + else True + ) + + log.debug(f"Only unique pairs: {only_unique_pairs}") + + postprocessing = PostProcessing( + postprocessing_parameter=potential_parameter.postprocessing_parameter.model_dump(), + dataset_statistic=remove_units_from_dataset_statistics(dataset_statistic), + ) + if use_training_mode_neighborlist: + from modelforge.potential.neighbors import NeighborListForTraining + + neighborlist = NeighborListForTraining( + cutoff=potential_parameter.core_parameter.maximum_interaction_radius, + only_unique_pairs=only_unique_pairs, + ) + else: + from modelforge.potential.neighbors import OrthogonalDisplacementFunction + + displacement_function = OrthogonalDisplacementFunction() + + if neighborlist_strategy == "verlet": + from modelforge.potential.neighbors import NeighborlistVerletNsq + + neighborlist = NeighborlistVerletNsq( + cutoff=potential_parameter.core_parameter.maximum_interaction_radius, + displacement_function=displacement_function, + only_unique_pairs=only_unique_pairs, + skin=verlet_neighborlist_skin, + ) + elif neighborlist_strategy == "brute": + from modelforge.potential.neighbors import NeighborlistBruteNsq + + neighborlist = NeighborlistBruteNsq( + cutoff=potential_parameter.core_parameter.maximum_interaction_radius, + displacement_function=displacement_function, + only_unique_pairs=only_unique_pairs, + ) + else: + raise ValueError( + f"Unsupported neighborlist strategy: {neighborlist_strategy}" + ) + + potential = Potential( + core_network, + neighborlist, + postprocessing, + jit=jit, + jit_neighborlist=False if use_training_mode_neighborlist else True, + ) + potential.eval() + return potential + + +from openff.units import unit + + +class NeuralNetworkPotentialFactory: + + @staticmethod + def generate_potential( + *, + potential_parameter: T_NNP_Parameters, + training_parameter: Optional[TrainingParameters] = None, + dataset_parameter: Optional[DatasetParameters] = None, + dataset_statistic: Dict[str, Dict[str, float]] = { + "training_dataset_statistics": { + "per_atom_energy_mean": unit.Quantity(0.0, unit.kilojoule_per_mole), + "per_atom_energy_stddev": unit.Quantity(1.0, unit.kilojoule_per_mole), + } + }, + potential_seed: Optional[int] = None, + use_training_mode_neighborlist: bool = False, + simulation_environment: Literal["PyTorch", "JAX"] = "PyTorch", + jit: bool = True, + inference_neighborlist_strategy: str = "verlet", + verlet_neighborlist_skin: Optional[float] = 0.1, + ) -> Union[Potential, JAXModel, pl.LightningModule]: + """ + Create an instance of a neural network potential for inference. + + Parameters + ---------- + potential_parameter : T_NNP_Parameters] + Parameters specific to the neural network potential. + training_parameter : Optional[TrainingParameters], optional + Parameters for configuring training (default is None). + dataset_parameter : Optional[DatasetParameters], optional + Parameters for configuring the dataset (default is None). + dataset_statistic : Dict[str, Dict[str, float]], optional + Dataset statistics for normalization (default is provided). + potential_seed : Optional[int], optional + Seed for random number generation (default is None). + use_training_mode_neighborlist : bool, optional + Whether to use neighborlist during training mode (default is False). + simulation_environment : Literal["PyTorch", "JAX"], optional + Specify whether to use PyTorch or JAX as the simulation environment + (default is "PyTorch"). + jit : bool, optional + Whether to use JIT compilation (default is True). + inference_neighborlist_strategy : Optional[str], optional + Neighborlist strategy for inference (default is "verlet"). other option is "brute". + verlet_neighborlist_skin : Optional[float], optional + Skin for the Verlet neighborlist (default is 0.1, units nanometers). + Returns + ------- + Union[Potential, JAXModel] + An instantiated neural network potential for training or inference. + """ + + log.debug(f"{training_parameter=}") + log.debug(f"{potential_parameter=}") + log.debug(f"{dataset_parameter=}") + + # obtain model for inference + potential = setup_potential( + potential_parameter=potential_parameter, + dataset_statistic=dataset_statistic, + use_training_mode_neighborlist=use_training_mode_neighborlist, + potential_seed=potential_seed, + jit=jit, + neighborlist_strategy=inference_neighborlist_strategy, + verlet_neighborlist_skin=verlet_neighborlist_skin, + ) + # Disable gradients for model parameters + for param in potential.parameters(): + param.requires_grad = False + # Set model to eval + potential.eval() + + if simulation_environment == "JAX": + # register nnp_input as pytree + from modelforge.utils.io import import_ + + jax = import_("jax") + from modelforge.jax import nnpinput_flatten, nnpinput_unflatten + + # registering NNPInput multiple times will result in a + # ValueError + try: + jax.tree_util.register_pytree_node( + NNPInput, + nnpinput_flatten, + nnpinput_unflatten, + ) + except ValueError: + log.debug("NNPInput already registered as pytree") + pass + return PyTorch2JAXConverter().convert_to_jax_model(potential) + else: + return potential + + @staticmethod + def generate_trainer( + *, + potential_parameter: T_NNP_Parameters, + runtime_parameter: Optional[RuntimeParameters] = None, + training_parameter: Optional[TrainingParameters] = None, + dataset_parameter: Optional[DatasetParameters] = None, + dataset_statistic: Dict[str, Dict[str, float]] = { + "training_dataset_statistics": { + "per_atom_energy_mean": unit.Quantity(0.0, unit.kilojoule_per_mole), + "per_atom_energy_stddev": unit.Quantity(1.0, unit.kilojoule_per_mole), + } + }, + potential_seed: Optional[int] = None, + use_default_dataset_statistic: bool = False, + ) -> "PotentialTrainer": + """ + Create a lightning trainer object to train the neural network potential. + + Parameters + ---------- + potential_parameter : T_NNP_Parameters] + Parameters specific to the neural network potential. + runtime_parameter : Optional[RuntimeParameters], optional + Parameters for configuring the runtime environment (default is + None). + training_parameter : Optional[TrainingParameters], optional + Parameters for configuring training (default is None). + dataset_parameter : Optional[DatasetParameters], optional + Parameters for configuring the dataset (default is None). + dataset_statistic : Dict[str, Dict[str, float]], optional + Dataset statistics for normalization (default is provided). + potential_seed : Optional[int], optional + Seed for random number generation (default is None). + use_default_dataset_statistic : bool, optional + Whether to use default dataset statistics (default is False). + Returns + ------- + PotentialTrainer + An instantiated neural network potential for training. + """ + from modelforge.utils.misc import seed_random_number + from modelforge.train.training import PotentialTrainer + + if potential_seed is not None: + log.info(f"Setting random seed to: {potential_seed}") + seed_random_number(potential_seed) + + log.debug(f"{training_parameter=}") + log.debug(f"{potential_parameter=}") + log.debug(f"{runtime_parameter=}") + log.debug(f"{dataset_parameter=}") + + trainer = PotentialTrainer( + potential_parameter=potential_parameter, + training_parameter=training_parameter, + dataset_parameter=dataset_parameter, + runtime_parameter=runtime_parameter, + potential_seed=potential_seed, + dataset_statistic=dataset_statistic, + use_default_dataset_statistic=use_default_dataset_statistic, + ) + return trainer + + +class PyTorch2JAXConverter: + """ + Wraps a PyTorch neural network potential instance in a Flax module using the + `pytorch2jax` library (https://github.com/subho406/Pytorch2Jax). + The converted model uses dlpack to convert between Pytorch and Jax tensors + in-memory and executes Pytorch backend inside Jax wrapped functions. + The wrapped modules are compatible with Jax backward-mode autodiff. + """ + + def convert_to_jax_model( + self, + nnp_instance: Potential, + ) -> JAXModel: + """ + Convert a PyTorch neural network instance to a JAX model. + + Parameters + ---------- + nnp_instance : + The PyTorch neural network instance to be converted. + + Returns + ------- + JAXModel + A JAX model containing the converted neural network function, parameters, and buffers. + """ + + jax_fn, params, buffers = self._convert_pytnn_to_jax(nnp_instance) + return JAXModel(jax_fn, params, buffers, nnp_instance.__class__.__name__) + + @staticmethod + def _convert_pytnn_to_jax( + nnp_instance: Potential, + ) -> Tuple[Callable, np.ndarray, np.ndarray]: + """Internal method to convert PyTorch neural network parameters and buffers to JAX format. + + Parameters + ---------- + nnp_instance : Any + The PyTorch neural network instance. + + Returns + ------- + Tuple[Callable, Any, Any] + A tuple containing the JAX function, parameters, and buffers. + """ + + # make sure + from modelforge.utils.io import import_ + + jax = import_("jax") + # use the wrapper to check if pytorch2jax is in the environment + + custom_vjp = import_("jax").custom_vjp + + # from jax import custom_vjp + convert_to_jax = import_("pytorch2jax").pytorch2jax.convert_to_jax + convert_to_pyt = import_("pytorch2jax").pytorch2jax.convert_to_pyt + # from pytorch2jax.pytorch2jax import convert_to_jax, convert_to_pyt + + import functorch + from functorch import make_functional_with_buffers + + # Convert the PyTorch model to a functional representation and extract the model function and parameters + model_fn, model_params, model_buffer = make_functional_with_buffers( + nnp_instance + ) + + # Convert the model parameters from PyTorch to JAX representations + model_params = jax.tree_map(convert_to_jax, model_params) + # Convert the model buffer from PyTorch to JAX representations + model_buffer = jax.tree_map(convert_to_jax, model_buffer) + + # Define the apply function using a custom VJP + @custom_vjp + def apply(params, *args, **kwargs): + # Convert the input data from JAX to PyTorch + params, args, kwargs = map( + lambda x: jax.tree_map(convert_to_pyt, x), (params, args, kwargs) + ) + # Apply the model function to the input data + out = model_fn(params, *args, **kwargs) + # Convert the output data from PyTorch to JAX + out = jax.tree_map(convert_to_jax, out) + return out + + # Define the forward and backward passes for the VJP + def apply_fwd(params, *args, **kwargs): + return apply(params, *args, **kwargs), (params, args, kwargs) + + def apply_bwd(res, grads): + params, args, kwargs = res + params, args, kwargs = map( + lambda x: jax.tree_map(convert_to_pyt, x), (params, args, kwargs) + ) + grads = jax.tree_map(convert_to_pyt, grads) + # Compute the gradients using the model function and convert them + # from JAX to PyTorch representations + grads = functorch.vjp(model_fn, params, *args, **kwargs)[1](grads) + return jax.tree_map(convert_to_jax, grads) + + apply.defvjp(apply_fwd, apply_bwd) + + # Return the apply function and the converted model parameters + return apply, model_params, model_buffer + + +def load_inference_model_from_checkpoint( + checkpoint_path: str, +) -> Union[Potential, JAXModel]: + """ + Creates an inference model from a checkpoint file. + It loads the checkpoint file, extracts the hyperparameters, and creates the model in inference mode. + + Parameters + ---------- + checkpoint_path : str + The path to the checkpoint file. + """ + import torch + + # Load the checkpoint + checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + + # Extract hyperparameters + hyperparams = checkpoint["hyper_parameters"] + potential_parameter = hyperparams["potential_parameter"] + dataset_statistic = hyperparams.get("dataset_statistic", None) + potential_seed = hyperparams.get("potential_seed", None) + + # Create the model in inference mode + potential = NeuralNetworkPotentialFactory.generate_potential( + potential_parameter=potential_parameter, + dataset_statistic=dataset_statistic, + potential_seed=potential_seed, + ) + + # Load the state dict into the model + potential.load_state_dict(checkpoint["state_dict"]) + + # Return the model + return potential diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 50848f5f..ee664a62 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -1,7 +1,15 @@ +""" +This module contains utility functions and classes for processing the output of the potential model. +""" + +from dataclasses import dataclass, field +from typing import Dict, Iterator, Union + import torch -from typing import Dict from openff.units import unit +from modelforge.dataset.utils import _ATOMIC_NUMBER_TO_ELEMENT + def load_atomic_self_energies(path: str) -> Dict[str, unit.Quantity]: """ @@ -64,11 +72,7 @@ class FromAtomToMoleculeReduction(torch.nn.Module): def __init__( self, - per_atom_property_name: str, - index_name: str, - output_name: str, reduction_mode: str = "sum", - keep_per_atom_property: bool = False, ): """ Initializes the per-atom property readout_operation module. @@ -88,49 +92,39 @@ def __init__( """ super().__init__() self.reduction_mode = reduction_mode - self.per_atom_property_name = per_atom_property_name - self.output_name = output_name - self.index_name = index_name - self.keep_per_atom_property = keep_per_atom_property - def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward( + self, indices: torch.Tensor, per_atom_property: torch.Tensor + ) -> torch.Tensor: """ Forward pass of the module. Parameters ---------- data : Dict[str, torch.Tensor] - The input data dictionary containing the per-atom property and index. + The input data dictionary containing the per-atom property and + index. Returns ------- Dict[str, torch.Tensor] The output data dictionary containing the per-molecule property. """ - indices = data[self.index_name].to(torch.int64) - per_atom_property = data[self.per_atom_property_name] + # Perform scatter add operation for atoms belonging to the same molecule - property_per_molecule_zeros = torch.zeros( - len(indices.unique()), + nr_of_molecules = torch.unique(indices).unsqueeze(1) + per_system_property = torch.zeros_like( + nr_of_molecules, dtype=per_atom_property.dtype, device=per_atom_property.device, ) - property_per_molecule = property_per_molecule_zeros.scatter_reduce( - 0, indices, per_atom_property, reduce=self.reduction_mode + return per_system_property.scatter_reduce( + 0, + indices.long().unsqueeze(1), + per_atom_property, + reduce=self.reduction_mode, ) - data[self.output_name] = property_per_molecule - if self.keep_per_atom_property is False: - del data[self.per_atom_property_name] - - return data - - -from dataclasses import dataclass, field -from typing import Dict, Iterator - -from openff.units import unit -from modelforge.dataset.utils import _ATOMIC_NUMBER_TO_ELEMENT @dataclass @@ -226,7 +220,9 @@ class ScaleValues(torch.nn.Module): """ def __init__( - self, mean: float, stddev: float, property: str, output_name: str + self, + mean: float, + stddev: float, ) -> None: """ Rescales values using the provided mean and standard deviation. @@ -246,10 +242,8 @@ def __init__( super().__init__() self.register_buffer("mean", torch.tensor([mean])) self.register_buffer("stddev", torch.tensor([stddev])) - self.property = property - self.output_name = output_name - def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, data: torch.Tensor) -> torch.Tensor: """ Rescales values using the provided mean and standard deviation. @@ -263,11 +257,178 @@ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: Dict[str, torch.Tensor] The output data dictionary containing the rescaled values. """ - data[self.output_name] = data[self.property] * self.stddev + self.mean + return data * self.stddev + self.mean + + +def default_charge_conservation( + per_atom_charge: torch.Tensor, + per_system_total_charge: torch.Tensor, + mol_indices: torch.Tensor, +) -> torch.Tensor: + """ + Adjusts partial atomic charges so that the sum of charges in each molecule + matches the desired total charge. + + This method is based on equation 14 from the PhysNet paper. + + Parameters + ---------- + partial_charges : torch.Tensor + Tensor of partial charges for all atoms in all molecules. + per_system_total_charge : torch.Tensor + Tensor of desired total charges for each molecule. + mol_indices : torch.Tensor + Tensor of integers indicating which molecule each atom belongs to. + + Returns + ------- + torch.Tensor + Tensor of corrected partial charges. + """ + # Calculate the sum of partial charges for each molecule + predicted_per_system_total_charge = torch.zeros_like( + per_system_total_charge, dtype=per_atom_charge.dtype + ).scatter_add_( + 0, + mol_indices.long().unsqueeze(1), + per_atom_charge, + ) + + # Calculate the number of atoms in each molecule + num_atoms_per_system = mol_indices.bincount( + minlength=per_system_total_charge.size(0) + ) + + # Calculate the correction factor for each molecule + correction_factors = ( + per_system_total_charge - predicted_per_system_total_charge + ) / num_atoms_per_system.unsqueeze(1) + + # Apply the correction to each atom's charge + per_atom_charge_corrected = per_atom_charge + correction_factors[mol_indices] + + return per_atom_charge_corrected + + +class ChargeConservation(torch.nn.Module): + def __init__(self, method="default"): + """ + Module to enforce charge conservation on partial atomic charges. + + Parameters + ---------- + method : str, optional, default='default' + The method to use for charge conservation. Currently, only 'default' + is supported. + + Methods + ------- + forward(data) + Applies charge conservation to the partial charges in the provided + data dictionary. + """ + + super().__init__() + self.method = method + if self.method == "default": + self.correct_partial_charges = default_charge_conservation + else: + raise ValueError(f"Unknown charge conservation method: {self.method}") + + def forward( + self, + data: Dict[str, torch.Tensor], + ): + """ + Apply charge conservation to partial charges in the data dictionary. + + Parameters + ---------- + data : Dict[str, torch.Tensor] + Dictionary containing the following keys: + - "per_atom_charge": + Tensor of partial charges for all atoms in the batch. + - "per_system_total_charge": + Tensor of desired total charges for each + molecule. + - "atomic_subsystem_indices": + Tensor indicating which molecule each atom belongs to. + + Returns + ------- + Dict[str, torch.Tensor] + Updated data dictionary with the key "per_atom_charge_corrected" + added, containing the corrected per-atom charges. + """ + data["per_atom_charge_uncorrected"] = data["per_atom_charge"] + data["per_atom_charge"] = self.correct_partial_charges( + data["per_atom_charge"], + data["per_system_total_charge"], + data["atomic_subsystem_indices"], + ) return data -from typing import Union +class PerAtomEnergy(torch.nn.Module): + + def __init__( + self, + per_atom_energy: Dict[str, bool], + dataset_statistics: Dict[str, float], + ): + """ + Process per atom energies. Depending on what has been requested in the per_atom_energy dictionary, the per atom energies are normalized and/or reduced to per system energies. + Parameters + ---------- + per_atom_energy : Dict[str, bool] + A dictionary containing the per atom energy processing options. + dataset_statistics : Dict[str, float] + + """ + super().__init__() + + if per_atom_energy.get("normalize"): + scale = ScaleValues( + dataset_statistics["per_atom_energy_mean"], + dataset_statistics["per_atom_energy_stddev"], + ) + else: + scale = ScaleValues(0.0, 1.0) + + self.scale = scale + + if per_atom_energy.get("from_atom_to_system_reduction"): + reduction = FromAtomToMoleculeReduction() + + self.reduction = reduction + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + per_atom_property, indices = ( + data["per_atom_energy"], + data["atomic_subsystem_indices"], + ) + scaled_values = self.scale(per_atom_property) + per_system_energy = self.reduction(indices, scaled_values) + + data["per_system_energy"] = per_system_energy + data["per_atom_energy"] = data["per_atom_energy"].detach() + + return data + + +class PerAtomCharge(torch.nn.Module): + + def __init__(self, per_atom_charge: Dict[str, bool]): + super().__init__() + from torch import nn + + if per_atom_charge["conserve"] == True: + self.conserve = ChargeConservation(per_atom_charge["conserve_strategy"]) + else: + self.conserve = nn.Identity() + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return self.conserve(data) class CalculateAtomicSelfEnergy(torch.nn.Module): @@ -298,7 +459,7 @@ def __init__( } self.atomic_self_energies = AtomicSelfEnergies(atomic_self_energies) - def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Calculates the molecular self energy. @@ -330,3 +491,87 @@ def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: data["ase_tensor"] = ase_tensor return data + + +class CoulombPotential(torch.nn.Module): + def __init__(self, cutoff: float): + """ + Computes the long-range electrostatic energy for a molecular system + based on predicted partial charges and pairwise distances between atoms. + + The implementation follows the methodology described in the PhysNet + paper, using a cutoff function to handle long-range interactions. + + Parameters + ---------- + cutoff : float + The cutoff distance beyond which the interactions are not + considered in nanometer. + + Attributes + ---------- + strategy : str + The strategy for computing long-range interactions. + cutoff_function : nn.Module + The cutoff function applied to the pairwise distances. + """ + super().__init__() + from .representation import PhysNetAttenuationFunction + + self.cutoff_function = PhysNetAttenuationFunction(cutoff) + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass to compute the long-range electrostatic energy. + + This function calculates the long-range electrostatic energy by considering + pairwise Coulomb interactions between atoms, applying a cutoff function to + handle long-range interactions. + + Parameters + ---------- + data : Dict[str, torch.Tensor] + Input data containing the following keys: + - 'per_atom_charge': Tensor of shape (N,) with partial charges for each atom. + - 'atomic_subsystem_indices': Tensor indicating the molecule each atom belongs to. + - 'pairwise_properties': Object containing pairwise distances and indices. + + Returns + ------- + Dict[str, torch.Tensor] + The input data dictionary with an additional key 'long_range_electrostatic_energy' + containing the computed long-range electrostatic energy. + """ + mol_indices = data["atomic_subsystem_indices"] + idx_i, idx_j = data["pair_indices"] + + # only unique paris + unique_pairs_mask = idx_i < idx_j + idx_i = idx_i[unique_pairs_mask] + idx_j = idx_j[unique_pairs_mask] + + # mask pairwise properties + pairwise_distances = data["d_ij"][unique_pairs_mask] + per_atom_charge = data["per_atom_charge"] + + # Initialize the long-range electrostatic energy + electrostatic_energy = torch.zeros_like(data["per_system_energy"]) + + # Apply the cutoff function to pairwise distances + phi_2r = self.cutoff_function(2 * pairwise_distances) + chi_r = phi_2r * (1 / torch.sqrt(pairwise_distances**2 + 1)) + ( + 1 - phi_2r + ) * (1 / pairwise_distances) + + # Compute the Coulomb interaction term + coulomb_interactions = (per_atom_charge[idx_i] * per_atom_charge[idx_j]) * chi_r + + # Sum over all interactions for each molecule + data["electrostatic_energy"] = ( + electrostatic_energy.scatter_add_( + 0, mol_indices.long().unsqueeze(1), coulomb_interactions + ) + * 138.96 + ) # in kj/mol nm + + return data diff --git a/modelforge/potential/representation.py b/modelforge/potential/representation.py new file mode 100644 index 00000000..8edee7df --- /dev/null +++ b/modelforge/potential/representation.py @@ -0,0 +1,736 @@ +import torch +from typing import Optional +from torch import nn +import numpy as np + + +class PhysNetAttenuationFunction(nn.Module): + def __init__(self, cutoff: float): + """ + Initialize the PhysNet attenuation function. + + Parameters + ---------- + cutoff : unit.Quantity + The cutoff distance. + """ + super().__init__() + self.register_buffer("cutoff", torch.tensor([cutoff])) + + def forward(self, d_ij: torch.Tensor): + + return torch.clamp( + ( + 1 + - 6 * torch.pow((d_ij / self.cutoff), 5) + + 15 * torch.pow((d_ij / self.cutoff), 4) + - 10 * torch.pow((d_ij / self.cutoff), 3) + ), + min=0, + ) + + +class CosineAttenuationFunction(nn.Module): + def __init__(self, cutoff: float): + """ + Behler-style cosine cutoff module. This anneals the signal smoothly to zero at the cutoff distance. + + NOTE: The cutoff is converted to nanometer and the input MUST be in nanomter too. + + Parameters: + ----------- + cutoff: unit.Quantity + The cutoff distance. + + """ + super().__init__() + self.register_buffer("cutoff", torch.tensor([cutoff])) + + def forward(self, d_ij: torch.Tensor): + """ + Compute the cosine cutoff for a distance tensor. + NOTE: the cutoff function doesn't care about units as long as they are consisten, + + Parameters + ----------- + d_ij : Tensor + Pairwise distance tensor in nanometer. Shape: [n_pairs, 1] + + Returns + -------- + Tensor + Cosine cutoff tensor. Shape: [n_pairs, 1] + """ + # Compute values of cutoff function + input_cut = 0.5 * ( + torch.cos(d_ij * np.pi / self.cutoff) + 1.0 + ) # NOTE: ANI adds 0.5 instead of 1. + # Remove contributions beyond the cutoff radius + input_cut *= (d_ij < self.cutoff).float() + return input_cut + + +class AngularSymmetryFunction(nn.Module): + """ + Initialize AngularSymmetryFunction module. + + """ + + def __init__( + self, + maximum_interaction_radius: float, + min_distance: float, + number_of_gaussians_for_asf: int = 8, + angle_sections: int = 4, + trainable: bool = False, + dtype: Optional[torch.dtype] = None, + ) -> None: + """ + Parameters + ---- + number_of_gaussian: Number of gaussian functions to use for angular symmetry function. + angular_cutoff: Cutoff distance for angular symmetry function. + angular_start: Starting distance for angular symmetry function. + ani_style: Whether to use ANI symmetry function style. + """ + + super().__init__() + from loguru import logger as log + + self.number_of_gaussians_asf = number_of_gaussians_for_asf + self.angular_cutoff = maximum_interaction_radius + self.cosine_cutoff = CosineAttenuationFunction(self.angular_cutoff) + _unitless_angular_cutoff = maximum_interaction_radius + self.angular_start = min_distance + _unitless_angular_start = min_distance + + # save constants + EtaA = angular_eta = 12.5 * 100 # FIXME hardcoded eta + Zeta = 14.1000 # FIXME hardcoded zeta + + if trainable: + self.EtaA = torch.tensor([EtaA], dtype=dtype) + self.Zeta = torch.tensor([Zeta], dtype=dtype) + self.Rca = torch.tensor([_unitless_angular_cutoff], dtype=dtype) + + else: + self.register_buffer("EtaA", torch.tensor([EtaA], dtype=dtype)) + self.register_buffer("Zeta", torch.tensor([Zeta], dtype=dtype)) + self.register_buffer( + "Rca", torch.tensor([_unitless_angular_cutoff], dtype=dtype) + ) + + # =============== + # # calculate shifts + # =============== + import math + + # ShfZ + angle_start = math.pi / (2 * angle_sections) + ShfZ = (torch.linspace(0, math.pi, angle_sections + 1) + angle_start)[:-1] + # ShfA + ShfA = torch.linspace( + _unitless_angular_start, + _unitless_angular_cutoff, + number_of_gaussians_for_asf + 1, + )[:-1] + # register shifts + if trainable: + self.ShfZ = ShfZ + self.ShfA = ShfA + else: + self.register_buffer("ShfZ", ShfZ) + self.register_buffer("ShfA", ShfA) + + # The length of angular subaev of a single species + self.angular_sublength = self.ShfA.numel() * self.ShfZ.numel() + + def forward(self, r_ij: torch.Tensor) -> torch.Tensor: + # calculate the angular sub aev + sub_aev = self.compute_angular_sub_aev(r_ij) + return sub_aev + + def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: + """ + Compute the angular subAEV terms of the center atom given neighbor + pairs. + + This correspond to equation (4) in the ANI paper. This function just + compute the terms. The sum in the equation is not computed. + + Parameters + ---------- + vectors12: torch.Tensor + Pairwise distance vectors. Shape: [2, n_pairs, 3] + Returns + ------- + torch.Tensor + Angular subAEV terms. Shape: [n_pairs, ShfZ_size * ShfA_size] + + """ + # vectors12 has shape: (2, n_pairs, 3) + distances12 = vectors12.norm(p=2, dim=-1) # Shape: (2, n_pairs) + distances_sum = distances12.sum(dim=0) / 2 # Shape: (n_pairs,) + fcj12 = self.cosine_cutoff(distances12) # Shape: (2, n_pairs) + fcj12_prod = fcj12.prod(dim=0) # Shape: (n_pairs,) + + # cos_angles: (n_pairs,) + + cos_angles = 0.95 * torch.nn.functional.cosine_similarity( + vectors12[0], vectors12[1], dim=-1 + ) + + angles = torch.acos(cos_angles) # Shape: (n_pairs,) + + # Prepare shifts for broadcasting + angles = angles.unsqueeze(-1) # Shape: (n_pairs, 1) + distances_sum = distances_sum.unsqueeze(-1) # Shape: (n_pairs, 1) + + # Compute factor1 + delta_angles = angles - self.ShfZ.view(1, -1) # Shape: (n_pairs, ShfZ_size) + factor1 = ( + (1 + torch.cos(delta_angles)) / 2 + ) ** self.Zeta # Shape: (n_pairs, ShfZ_size) + + # Compute factor2 + delta_distances = distances_sum - self.ShfA.view( + 1, -1 + ) # Shape: (n_pairs, ShfA_size) + factor2 = torch.exp( + -self.EtaA * delta_distances**2 + ) # Shape: (n_pairs, ShfA_size) + + # Compute the outer product of factor1 and factor2 efficiently + # fcj12_prod: (n_pairs, 1, 1) + fcj12_prod = fcj12_prod.unsqueeze(-1).unsqueeze(-1) # Shape: (n_pairs, 1, 1) + + # factor1: (n_pairs, 1, ShfZ_size) + factor1 = factor1.unsqueeze(-2) + # factor2: (n_pairs, ShfA_size, 1) + factor2 = factor2.unsqueeze(-1) + + # Compute ret: (n_pairs, ShfZ_size, ShfA_size) + ret = 2 * factor1 * factor2 * fcj12_prod + + # Flatten the last two dimensions to get the final subAEV + # ret: (n_pairs, ShfZ_size * ShfA_size) + ret = ret.reshape(distances12.size(dim=1), -1) + + return ret + + +import math +from abc import ABC, abstractmethod + +from torch.nn import functional + + +class RadialBasisFunctionCore(nn.Module, ABC): + + def __init__(self, number_of_radial_basis_functions): + super().__init__() + self.number_of_radial_basis_functions = number_of_radial_basis_functions + + @abstractmethod + def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + """ + Parameters + --------- + nondimensionalized_distances: torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + Nondimensional quantities that depend on pairwise distances. + + Returns + --------- + torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + """ + pass + + +class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): + + def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + assert nondimensionalized_distances.ndim == 2 + assert ( + nondimensionalized_distances.shape[1] + == self.number_of_radial_basis_functions + ) + + return torch.exp(-(nondimensionalized_distances**2)) + + +class RadialBasisFunction(nn.Module, ABC): + + def __init__( + self, + radial_basis_function: RadialBasisFunctionCore, + dtype: torch.dtype, + prefactor: float = 1, + trainable_prefactor: bool = False, + ): + super().__init__() + if trainable_prefactor: + self.prefactor = nn.Parameter(torch.tensor([prefactor], dtype=dtype)) + else: + self.register_buffer("prefactor", torch.tensor([prefactor], dtype=dtype)) + self.radial_basis_function = radial_basis_function + + @abstractmethod + def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + """ + Parameters + --------- + distances: torch.Tensor, shape [number_of_pairs, 1] + Distances between atoms in each pair in nanometers. + + Returns + --------- + torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + Nondimensional quantities computed from the distances. + """ + pass + + def forward(self, distances: torch.Tensor) -> torch.Tensor: + """ + The input distances have implicit units of nanometers by the convention + of modelforge. This function applies nondimensionalization + transformations on the distances and passes the dimensionless result to + RadialBasisFunctionCore. There can be several nondimsionalization + transformations, corresponding to each element along the + number_of_radial_basis_functions axis in the output. + + Parameters + --------- + distances: torch.Tensor, shape [number_of_pairs, 1] + Distances between atoms in each pair in nanometers. + + Returns + --------- + torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + Output of radial basis functions. + """ + nondimensionalized_distances = self.nondimensionalize_distances(distances) + return self.prefactor * self.radial_basis_function(nondimensionalized_distances) + + +class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): + """ + Shifts inputs by a set of centers and scales by a set of scale factors before passing into the standard Gaussian. + """ + + def __init__( + self, + number_of_radial_basis_functions: int, + max_distance: float, + min_distance: float = 0.0, + dtype: torch.dtype = torch.float32, + prefactor: float = 1.0, + trainable_prefactor: bool = False, + trainable_centers_and_scale_factors: bool = False, + ): + """ + Parameters + --------- + number_of_radial_basis_functions: int + Number of radial basis functions to use. + max_distance: unit.Quantity + Maximum distance to consider for symmetry functions. + min_distance: unit.Quantity + Minimum distance to consider. + dtype: torch.dtype, default None + Data type for computations. + prefactor: float + Scalar factor by which to multiply output of radial basis functions. + trainable_prefactor: bool, default False + Whether prefactor is trainable + trainable_centers_and_scale_factors: bool, default False + Whether centers and scale factors are trainable. + """ + + super().__init__( + GaussianRadialBasisFunctionCore(number_of_radial_basis_functions), + dtype, + prefactor, + trainable_prefactor, + ) + self.number_of_radial_basis_functions = number_of_radial_basis_functions + self.dtype = dtype + self.trainable_centers_and_scale_factors = trainable_centers_and_scale_factors + # convert to nanometer + _max_distance_in_nanometer = max_distance + _min_distance_in_nanometer = min_distance + + # calculate radial basis centers + radial_basis_centers = self.calculate_radial_basis_centers( + self.number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + self.dtype, + ) + # calculate scale factors + radial_scale_factor = self.calculate_radial_scale_factor( + self.number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + self.dtype, + ) + + # either add as parameters or register buffers + if self.trainable_centers_and_scale_factors: + self.radial_basis_centers = radial_basis_centers + self.radial_scale_factor = radial_scale_factor + else: + self.register_buffer("radial_basis_centers", radial_basis_centers) + self.register_buffer("radial_scale_factor", radial_scale_factor) + + @staticmethod + @abstractmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + """ + NOTE: centers have units of nanometers + """ + pass + + @staticmethod + @abstractmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + """ + NOTE: radial scale factors have units of nanometers + """ + pass + + def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + # Here, self.radial_scale_factor is interpreted as sqrt(2) times the standard deviation of the Gaussian. + diff = distances - self.radial_basis_centers + return diff / self.radial_scale_factor + + +class SchnetRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): + """ + Implementation of the radial basis function as used by the SchNet neural network + """ + + def __init__( + self, + number_of_radial_basis_functions: int, + max_distance: float, + min_distance: float = 0.0, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, + ): + """ + Parameters + --------- + number_of_radial_basis_functions: int + Number of radial basis functions to use. + max_distance: unit.Quantity + Maximum distance to consider for symmetry functions. + min_distance: unit.Quantity + Minimum distance to consider. + dtype: torch.dtype, default None + Data type for computations. + trainable_centers_and_scale_factors: bool, default False + Whether centers and scale factors are trainable. + """ + super().__init__( + number_of_radial_basis_functions, + max_distance, + min_distance, + dtype, + trainable_prefactor=False, + trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, + ) + + @staticmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + return torch.linspace( + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype=dtype, + ) + + @staticmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + scale_factors = torch.linspace( + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype=dtype, + ) + + widths = torch.abs(scale_factors[1] - scale_factors[0]) * torch.ones_like( + scale_factors + ) + + scale_factors = math.sqrt(2) * widths + return scale_factors + + +class AniRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): + """ + Implementation of the radial basis function as used by the ANI neural network + """ + + def __init__( + self, + number_of_radial_basis_functions, + max_distance: float, + min_distance: float = 0.0, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, + ): + """ + Parameters + --------- + number_of_radial_basis_functions: int + Number of radial basis functions to use. + max_distance: float + Maximum distance to consider for symmetry functions. + min_distance: float + Minimum distance to consider. + dtype: torch.dtype, default torch.float32 + Data type for computations. + trainable_centers_and_scale_factors: bool, default False + Whether centers and scale factors are trainable. + """ + super().__init__( + number_of_radial_basis_functions, + max_distance, + min_distance, + dtype, + prefactor=0.25, + trainable_prefactor=False, + trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, + ) + + @staticmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + centers = torch.linspace( + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions + 1, + dtype=dtype, + )[:-1] + return centers + + @staticmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + # ANI uses a predefined scaling factor + scale_factors = torch.full( + (number_of_radial_basis_functions,), (19.7 * 100) ** -0.5 + ) + return scale_factors + + +class PhysNetRadialBasisFunction(RadialBasisFunction): + """ + Implementation of the radial basis function as used by the PysNet neural network + """ + + def __init__( + self, + number_of_radial_basis_functions: int, + max_distance: float, + min_distance: float = 0.0, + alpha: float = 0.1, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, + ): + """ + Parameters + ---------- + number_of_radial_basis_functions : int + Number of radial basis functions to use. + max_distance : float + Maximum distance to consider for symmetry functions. + min_distance : float + Minimum distance to consider, by default 0.0 * unit.nanometer. + alpha: float + Scale factor used to nondimensionalize the input to all exp calls. The PhysNet paper implicitly divides by 1 + Angstrom within exponentials. Note that this is distinct from the unitless scale factors used outside the + exp but within the Gaussian. + dtype : torch.dtype, optional + Data type for computations, by default torch.float32. + trainable_centers_and_scale_factors : bool, optional + Whether centers and scale factors are trainable, by default False. + """ + + super().__init__( + GaussianRadialBasisFunctionCore(number_of_radial_basis_functions), + trainable_prefactor=False, + dtype=dtype, + ) + self._min_distance_in_nanometer = min_distance + self._alpha_in_nanometer = alpha + radial_basis_centers = self.calculate_radial_basis_centers( + number_of_radial_basis_functions, + max_distance, + min_distance, + alpha, + dtype, + ) + # calculate scale factors + radial_scale_factor = self.calculate_radial_scale_factor( + number_of_radial_basis_functions, + max_distance, + min_distance, + alpha, + dtype, + ) + + if trainable_centers_and_scale_factors: + self.radial_basis_centers = radial_basis_centers + self.radial_scale_factor = radial_scale_factor + else: + self.register_buffer("radial_basis_centers", radial_basis_centers) + self.register_buffer("radial_scale_factor", radial_scale_factor) + + @staticmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + max_distance, + min_distance, + alpha, + dtype, + ): + # initialize centers according to the default values in PhysNet (see + # mu_k in Figure 2 caption of + # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181) NOTE: Unlike + # GaussianRadialBasisFunctionWithScaling, the centers are unitless. + + start_value = torch.exp( + torch.scalar_tensor( + ((-max_distance + min_distance) / alpha), + dtype=dtype, + ) + ) + centers = torch.linspace( + start_value, 1, number_of_radial_basis_functions, dtype=dtype + ) + return centers + + @staticmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + max_distance, + min_distance, + alpha, + dtype, + ): + # initialize according to the default values in PhysNet (see beta_k in + # Figure 2 caption) NOTES: + # - Unlike GaussianRadialBasisFunctionWithScaling, the scale factors are + # unitless. + # - Each element of radial_square_factor here is the reciprocal of the + # square root of beta_k in the Eq. 7 of the PhysNet paper. This way, it + # is consistent with the sqrt(2) * standard deviation interpretation of + # radial_scale_factor in GaussianRadialBasisFunctionWithScaling + return torch.full( + (number_of_radial_basis_functions,), + (2 * (1 - math.exp(((-max_distance + min_distance) / alpha)))) + / number_of_radial_basis_functions, + dtype=dtype, + ) + + def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + # Transformation within the outer exp of PhysNet Eq. 7 + # NOTE: the PhysNet paper implicitly multiplies by 1/Angstrom within the inner exp but distances are in + # nanometers, so we multiply by 10/nanometer + + return ( + torch.exp( + (-distances + self._min_distance_in_nanometer) + / self._alpha_in_nanometer + ) + - self.radial_basis_centers + ) / self.radial_scale_factor + + +class TensorNetRadialBasisFunction(PhysNetRadialBasisFunction): + """ + The only difference from PhysNetRadialBasisFunction is that alpha is set + to 1 angstrom only for the purpose of unitless calculations. + """ + + @staticmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + max_distance, + min_distance, + alpha, + dtype, + ): + alpha = 0.1 + start_value = torch.exp( + torch.scalar_tensor( + ((-max_distance + min_distance) / alpha), + dtype=dtype, + ) + ) + centers = torch.linspace( + start_value, 1, number_of_radial_basis_functions, dtype=dtype + ) + return centers + + @staticmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + max_distance, + min_distance, + alpha, + dtype, + ): + alpha = 0.1 + start_value = torch.exp( + torch.scalar_tensor(((-max_distance + min_distance) / alpha)) + ) + radial_scale_factor = torch.full( + (number_of_radial_basis_functions,), + 2 / number_of_radial_basis_functions * (1 - start_value), + dtype=dtype, + ) + + return radial_scale_factor + + def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + # Transformation within the outer exp of PhysNet Eq. 7 NOTE: the PhysNet + # paper implicitly multiplies by 1/Angstrom within the inner exp but + # distances are in nanometers, so we multiply by 10/nanometer + + return ( + torch.exp( + (-distances + self._min_distance_in_nanometer) + / self._alpha_in_nanometer + ) + - self.radial_basis_centers + ) / self.radial_scale_factor diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index 4e0804c5..83a0d37b 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -1,179 +1,178 @@ -from dataclasses import dataclass +""" +SAKE - Spatial Attention Kinetic Networks with E(n) Equivariance +""" -import torch.nn as nn -from loguru import logger as log from typing import Dict, Tuple -from openff.units import unit -from .models import InputPreparation, NNPInput, BaseNetwork, CoreNetwork - -from .models import PairListOutputs -from .utils import ( - Dense, - scatter_softmax, - PhysNetRadialBasisFunction, -) -from modelforge.dataset.dataset import NNPInput + import torch -import torch.nn.functional as F +import torch.nn as nn +from loguru import logger as log +from modelforge.utils.prop import NNPInput +from modelforge.potential.neighbors import PairlistData -@dataclass -class SAKENeuralNetworkInput: +from .utils import DenseWithCustomDist, scatter_softmax +from .representation import PhysNetRadialBasisFunction + + +class MultiplySigmoid(nn.Module): """ - A dataclass designed to structure the inputs for SAKE neural network potentials, ensuring - an efficient and structured representation of atomic systems for energy computation and - property prediction within the SAKE framework. - - Attributes - ---------- - atomic_numbers : torch.Tensor - Atomic numbers for each atom in the system(s). Shape: [num_atoms]. - positions : torch.Tensor - XYZ coordinates of each atom. Shape: [num_atoms, 3]. - atomic_subsystem_indices : torch.Tensor - Maps each atom to its respective subsystem or molecule, useful for systems with multiple - molecules. Shape: [num_atoms]. - pair_indices : torch.Tensor - Indicates indices of atom pairs, essential for computing pairwise features. Shape: [2, num_pairs]. - number_of_atoms : int - Total number of atoms in the batch, facilitating batch-wise operations. - atomic_embedding : torch.Tensor - Embeddings or features for each atom, potentially derived from atomic numbers or learned. Shape: [num_atoms, embedding_dim]. - - Notes - ----- - The `SAKENeuralNetworkInput` dataclass encapsulates essential inputs required by the SAKE neural network - model for accurately predicting system energies and properties. It includes atomic positions, atomic types, - and connectivity information, crucial for a detailed representation of atomistic systems. - - Examples - -------- - >>> sake_input = SAKENeuralNetworkInput( - ... atomic_numbers=torch.tensor([1, 6, 6, 8]), - ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]), - ... atomic_subsystem_indices=torch.tensor([0, 0, 0, 0]), - ... pair_indices=torch.tensor([[0, 1], [0, 2], [1, 2]]).T, - ... number_of_atoms=4, - ... atomic_embedding=torch.randn(4, 5) # Example atomic embeddings - ... ) + Custom activation module that multiplies the sigmoid output by a factor of 2.0. + This module is compatible with TorchScript. """ - pair_indices: torch.Tensor - number_of_atoms: int - positions: torch.Tensor - atomic_numbers: torch.Tensor - atomic_subsystem_indices: torch.Tensor - atomic_embedding: torch.Tensor + def __init__(self, factor: float = 2.0): + super().__init__() + self.factor = factor + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.factor * torch.sigmoid(x) -class SAKECore(CoreNetwork): - """SAKE - spatial attention kinetic networks with E(n) equivariance. - Reference: - Wang, Yuanqing and Chodera, John D. ICLR 2023. https://openreview.net/pdf?id=3DIpIf3wQMC +from typing import List, Tuple - """ + +class SAKECore(torch.nn.Module): def __init__( self, - max_Z: int = 100, - number_of_atom_features: int = 64, - number_of_interaction_modules: int = 6, - number_of_spatial_attention_heads: int = 4, - number_of_radial_basis_functions: int = 50, - cutoff: unit.Quantity = 5.0 * unit.angstrom, + featurization: Dict[str, Dict[str, int]], + number_of_interaction_modules: int, + number_of_spatial_attention_heads: int, + number_of_radial_basis_functions: int, + maximum_interaction_radius: float, + activation_function_parameter: Dict[str, str], + predicted_properties: List[str], + predicted_dim: List[int], epsilon: float = 1e-8, + potential_seed: int = -1, ): - from .processing import FromAtomToMoleculeReduction - log.debug("Initializing SAKE model.") + from modelforge.utils.misc import seed_random_number + from modelforge.potential.utils import DenseWithCustomDist + from modelforge.potential import FeaturizeInput + + if potential_seed != -1: + seed_random_number(potential_seed) + log.debug("Initializing the SAKE architecture.") super().__init__() + + self.activation_function = activation_function_parameter["activation_function"] + self.nr_interaction_blocks = number_of_interaction_modules + number_of_per_atom_features = int( + featurization["atomic_number"]["number_of_per_atom_features"] + ) self.nr_heads = number_of_spatial_attention_heads - self.max_Z = max_Z + self.number_of_per_atom_features = number_of_per_atom_features + # featurize the atomic input - self.embedding = Dense(max_Z, number_of_atom_features) + self.featurize_input = FeaturizeInput(featurization) self.energy_layer = nn.Sequential( - Dense(number_of_atom_features, number_of_atom_features), - nn.SiLU(), - Dense(number_of_atom_features, 1), + DenseWithCustomDist( + number_of_per_atom_features, + number_of_per_atom_features, + activation_function=self.activation_function, + ), + DenseWithCustomDist(number_of_per_atom_features, 1), ) # initialize the interaction networks self.interaction_modules = nn.ModuleList( SAKEInteraction( - nr_atom_basis=number_of_atom_features, - nr_edge_basis=number_of_atom_features, - nr_edge_basis_hidden=number_of_atom_features, - nr_atom_basis_hidden=number_of_atom_features, - nr_atom_basis_spatial_hidden=number_of_atom_features, - nr_atom_basis_spatial=number_of_atom_features, - nr_atom_basis_velocity=number_of_atom_features, - nr_coefficients=(self.nr_heads * number_of_atom_features), + nr_atom_basis=number_of_per_atom_features, + nr_edge_basis=number_of_per_atom_features, + nr_edge_basis_hidden=number_of_per_atom_features, + nr_atom_basis_hidden=number_of_per_atom_features, + nr_atom_basis_spatial_hidden=number_of_per_atom_features, + nr_atom_basis_spatial=number_of_per_atom_features, + nr_atom_basis_velocity=number_of_per_atom_features, + nr_coefficients=(self.nr_heads * number_of_per_atom_features), nr_heads=self.nr_heads, - activation=torch.nn.SiLU(), - cutoff=cutoff, + activation=self.activation_function, + maximum_interaction_radius=maximum_interaction_radius, number_of_radial_basis_functions=number_of_radial_basis_functions, epsilon=epsilon, - scale_factor=(1.0 * unit.nanometer), # TODO: switch to angstrom + scale_factor=1.0, ) for _ in range(self.nr_interaction_blocks) ) - - def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" - ) -> SAKENeuralNetworkInput: - # Perform atomic embedding - - number_of_atoms = data.atomic_numbers.shape[0] - - atomic_embedding = self.embedding( - F.one_hot(data.atomic_numbers.long(), num_classes=self.max_Z).to( - self.embedding.weight.dtype + # reduce per-atom features to per atom scalar + # Initialize output layers based on configuration + self.output_layers = nn.ModuleDict() + for property, dim in zip(predicted_properties, predicted_dim): + self.output_layers[property] = nn.Sequential( + DenseWithCustomDist( + number_of_per_atom_features, + number_of_per_atom_features, + activation_function=self.activation_function, + ), + DenseWithCustomDist( + number_of_per_atom_features, + int(dim), + ), ) - ) - nnp_input = SAKENeuralNetworkInput( - pair_indices=pairlist_output.pair_indices, - number_of_atoms=number_of_atoms, - positions=data.positions.to(self.embedding.weight.dtype), - atomic_numbers=data.atomic_numbers, - atomic_subsystem_indices=data.atomic_subsystem_indices, - atomic_embedding=atomic_embedding, - ) - - return nnp_input - - def compute_properties(self, data: SAKENeuralNetworkInput): + def compute_properties( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: """ - Compute atomic representations/embeddings. + Compute atomic properties. Parameters ---------- - data: SAKENeuralNetworkInput - Dataclass containing atomic properties, embeddings, and pairlist. + data : SAKENeuralNetworkInput + Input data for the SAKE neural network. Returns ------- Dict[str, torch.Tensor] Dictionary containing per-atom energy predictions and atomic subsystem indices. """ - # extract properties from pairlist - h = data.atomic_embedding + h = self.featurize_input(data) x = data.positions v = torch.zeros_like(x) for interaction_mod in self.interaction_modules: - h, x, v = interaction_mod(h, x, v, data.pair_indices) - - # Use squeeze to remove dimensions of size 1 - E_i = self.energy_layer(h).squeeze(1) + h, x, v = interaction_mod(h, x, v, pairlist_output.pair_indices) return { - "per_atom_energy": E_i, + "per_atom_scalar_representation": h, "atomic_subsystem_indices": data.atomic_subsystem_indices, + "atomic_numbers": data.atomic_numbers, } + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: + """ + Implements the forward pass through the network. + + Parameters + ---------- + data : NNPInput + Contains input data for the batch obtained directly from the + dataset, including atomic numbers, positions, and other relevant + fields. + pairlist_output : PairListOutputs + Contains the indices for the selected pairs and their associated + distances and displacement vectors. + + Returns + ------- + Dict[str, torch.Tensor] + The calculated per-atom properties and other properties from the + forward pass. + """ + # perform the forward pass implemented in the subclass + results = self.compute_properties(data, pairlist_output) + # Compute all specified outputs + atomic_embedding = results["per_atom_scalar_representation"] + for output_name, output_layer in self.output_layers.items(): + results[output_name] = output_layer(atomic_embedding) + + return results + class SAKEInteraction(nn.Module): """ @@ -194,10 +193,10 @@ def __init__( nr_coefficients: int, nr_heads: int, activation: nn.Module, - cutoff: unit.Quantity, + maximum_interaction_radius: float, number_of_radial_basis_functions: int, epsilon: float, - scale_factor: unit.Quantity, + scale_factor: float, ): """ Parameters @@ -220,7 +219,7 @@ def __init__( Number of coefficients for spatial attention. activation : Callable Activation function to use. - cutoff : unit.Quantity + maximum_interaction_radius : unit.Quantity Distance parameter for setting scale factors in radial basis functions. number_of_radial_basis_functions: int Number of radial basis functions. @@ -243,31 +242,35 @@ def __init__( self.epsilon = epsilon self.radial_symmetry_function_module = PhysNetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, - max_distance=cutoff, + max_distance=maximum_interaction_radius, dtype=torch.float32, ) self.node_mlp = nn.Sequential( - Dense( + DenseWithCustomDist( self.nr_atom_basis + self.nr_heads * self.nr_edge_basis + self.nr_atom_basis_spatial, self.nr_atom_basis_hidden, - activation=activation, + activation_function=activation, + ), + DenseWithCustomDist( + self.nr_atom_basis_hidden, + self.nr_atom_basis, + activation_function=activation, ), - Dense(self.nr_atom_basis_hidden, self.nr_atom_basis, activation=activation), ) self.post_norm_mlp = nn.Sequential( - Dense( + DenseWithCustomDist( self.nr_coefficients, self.nr_atom_basis_spatial_hidden, - activation=activation, + activation_function=activation, ), - Dense( + DenseWithCustomDist( self.nr_atom_basis_spatial_hidden, self.nr_atom_basis_spatial, - activation=activation, + activation_function=activation, ), ) @@ -276,40 +279,42 @@ def __init__( ) self.edge_mlp_out = nn.Sequential( - Dense( + DenseWithCustomDist( self.nr_atom_basis * 2 + number_of_radial_basis_functions + 1, self.nr_edge_basis_hidden, - activation=activation, + activation_function=activation, ), nn.Linear(nr_edge_basis_hidden, nr_edge_basis), ) - self.semantic_attention_mlp = Dense( - self.nr_edge_basis, self.nr_heads, activation=nn.CELU(alpha=2.0) + self.semantic_attention_mlp = DenseWithCustomDist( + self.nr_edge_basis, self.nr_heads, activation_function=nn.CELU(alpha=2.0) ) self.velocity_mlp = nn.Sequential( - Dense( - self.nr_atom_basis, self.nr_atom_basis_velocity, activation=activation + DenseWithCustomDist( + self.nr_atom_basis, + self.nr_atom_basis_velocity, + activation_function=activation, ), - Dense( + DenseWithCustomDist( self.nr_atom_basis_velocity, 1, - activation=lambda x: 2.0 * F.sigmoid(x), + activation_function=MultiplySigmoid(factor=2.0), bias=False, ), ) - self.x_mixing_mlp = Dense( + self.x_mixing_mlp = DenseWithCustomDist( self.nr_heads * self.nr_edge_basis, self.nr_coefficients, bias=False, - activation=nn.Tanh(), + activation_function=nn.Tanh(), ) - self.v_mixing_mlp = Dense(self.nr_coefficients, 1, bias=False) + self.v_mixing_mlp = DenseWithCustomDist(self.nr_coefficients, 1, bias=False) - self.scale_factor_in_nanometer = scale_factor.m_as(unit.nanometer) + self.scale_factor_in_nanometer = scale_factor def update_edge(self, h_i_by_pair, h_j_by_pair, d_ij): """Compute intermediate edge features for semantic attention. @@ -331,11 +336,18 @@ def update_edge(self, h_i_by_pair, h_j_by_pair, d_ij): Intermediate edge features. Shape [nr_pairs, nr_edge_basis]. """ h_ij_cat = torch.cat([h_i_by_pair, h_j_by_pair], dim=-1) - h_ij_filtered = self.radial_symmetry_function_module(d_ij.unsqueeze(-1)).squeeze(-2) * self.edge_mlp_in( - h_ij_cat - ) + h_ij_filtered = self.radial_symmetry_function_module( + d_ij.unsqueeze(-1) + ).squeeze(-2) * self.edge_mlp_in(h_ij_cat) return self.edge_mlp_out( - torch.cat([h_ij_cat, h_ij_filtered, d_ij.unsqueeze(-1) / self.scale_factor_in_nanometer], dim=-1) + torch.cat( + [ + h_ij_cat, + h_ij_filtered, + d_ij.unsqueeze(-1) / self.scale_factor_in_nanometer, + ], + dim=-1, + ) ) def update_node(self, h, h_i_semantic, h_i_spatial): @@ -381,7 +393,7 @@ def update_velocity(self, v, h, combinations, idx_i): torch.Tensor Updated velocity features. Shape [nr_of_atoms_in_systems, geometry_basis]. """ - v_ij = self.v_mixing_mlp(combinations.transpose(-1, -2)).squeeze(-1) + v_ij = self.v_mixing_mlp(combinations.transpose(-1, -2)).squeeze(-1).to(v.dtype) expanded_idx_i = idx_i.view(-1, 1).expand_as(v_ij) dv = torch.zeros_like(v).scatter_reduce( 0, expanded_idx_i, v_ij, "mean", include_self=False @@ -408,7 +420,9 @@ def get_combinations(self, h_ij_semantic, dir_ij): # p: nr_pairs, x: geometry_basis, c: nr_coefficients return torch.einsum("px,pc->pcx", dir_ij, self.x_mixing_mlp(h_ij_semantic)) - def get_spatial_attention(self, combinations, idx_i, nr_atoms): + def get_spatial_attention( + self, combinations: torch.Tensor, idx_i: torch.Tensor, nr_atoms: int + ): """Compute spatial attention. Wang and Chodera (2023) Sec. 4 Eq. 6. @@ -438,7 +452,9 @@ def get_spatial_attention(self, combinations, idx_i, nr_atoms): combinations_norm_square = (combinations_mean**2).sum(dim=-1) return self.post_norm_mlp(combinations_norm_square) - def aggregate(self, h_ij_semantic, idx_i, nr_atoms): + def aggregate( + self, h_ij_semantic: torch.Tensor, idx_i: torch.Tensor, nr_atoms: int + ): """Aggregate edge semantic attention over all senders connected to a receiver. Wang and Chodera (2023) Sec. 5 Algorithm 1, step labelled "Neighborhood aggregation". @@ -464,7 +480,13 @@ def aggregate(self, h_ij_semantic, idx_i, nr_atoms): ) return zeros.scatter_add(0, expanded_idx_i, h_ij_semantic) - def get_semantic_attention(self, h_ij_edge, idx_i, idx_j, nr_atoms): + def get_semantic_attention( + self, + h_ij_edge: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + nr_atoms: int, + ): """Compute semantic attention. Softmax is over all senders connected to a receiver. Wang and Chodera (2023) Sec. 5 Eq. 9-10. @@ -494,7 +516,6 @@ def get_semantic_attention(self, h_ij_edge, idx_i, idx_j, nr_atoms): expanded_idx_i, dim=0, dim_size=nr_atoms, - device=h_ij_edge.device, ) # p: nr_pairs, f: nr_edge_basis, h: nr_heads return torch.reshape( @@ -522,8 +543,8 @@ def forward( Tuple[torch.Tensor, torch.Tensor, torch.Tensor] Updated scalar and vector representations (h, x, v) with same shapes as input. """ - idx_i, idx_j = pairlist - nr_of_atoms_in_all_systems, _ = x.shape + idx_i, idx_j = pairlist.unbind(0) + nr_of_atoms_in_all_systems = int(x.size(dim=0)) r_ij = x[idx_j] - x[idx_i] d_ij = torch.sqrt((r_ij**2).sum(dim=1) + self.epsilon) dir_ij = r_ij / (d_ij.unsqueeze(-1) + self.epsilon) @@ -546,63 +567,3 @@ def forward( x_updated = x + v_updated return h_updated, x_updated, v_updated - - -from typing import Optional, List, Union - - -class SAKE(BaseNetwork): - def __init__( - self, - max_Z: int, - number_of_atom_features: int, - number_of_interaction_modules: int, - number_of_spatial_attention_heads: int, - number_of_radial_basis_functions: int, - cutoff: unit.Quantity, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic: Optional[Dict[str, float]] = None, - epsilon: float = 1e-8, - ): - from modelforge.utils.units import _convert - self.only_unique_pairs = False # NOTE: for pairlist - super().__init__( - dataset_statistic=dataset_statistic, - postprocessing_parameter=postprocessing_parameter, - cutoff=_convert(cutoff), - ) - - self.core_module = SAKECore( - max_Z=max_Z, - number_of_atom_features=number_of_atom_features, - number_of_interaction_modules=number_of_interaction_modules, - number_of_spatial_attention_heads=number_of_spatial_attention_heads, - number_of_radial_basis_functions=number_of_radial_basis_functions, - cutoff=_convert(cutoff), - epsilon=epsilon, - ) - - - def _config_prior(self): - log.info("Configuring SAKE model hyperparameter prior distribution") - from modelforge.utils.io import import_ - - tune = import_("ray").tune - # from ray import tune - - from modelforge.potential.utils import shared_config_prior - - prior = { - "number_of_atom_features": tune.randint(2, 256), - "number_of_modules": tune.randint(3, 8), - "number_of_spatial_attention_heads": tune.randint(2, 5), - "cutoff": tune.uniform(5, 10), - "number_of_radial_basis_functions": tune.randint(8, 32), - } - prior.update(shared_config_prior()) - return prior - - def combine_per_atom_properties( - self, values: Dict[str, torch.Tensor] - ) -> torch.Tensor: - return values diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index 19fb9d8c..8383f09d 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -1,331 +1,348 @@ -from dataclasses import dataclass, field -from typing import Dict, Optional +""" +SchNet neural network potential for modeling quantum interactions. +""" + +from typing import Dict, List, Type import torch import torch.nn as nn from loguru import logger as log -from openff.units import unit - -from modelforge.potential.utils import NeuralNetworkData -from .models import InputPreparation, NNPInput, BaseNetwork, CoreNetwork, PairListOutputs - - -@dataclass -class SchnetNeuralNetworkData(NeuralNetworkData): - """ - A dataclass to structure the inputs specifically for SchNet-based neural network potentials, including the necessary - geometric and chemical information, along with the radial symmetry function expansion (`f_ij`) and the cosine cutoff - (`f_cutoff`) to accurately represent atomistic systems for energy predictions. - - Attributes - ---------- - pair_indices : torch.Tensor - A 2D tensor of shape [2, num_pairs], indicating the indices of atom pairs within a molecule or system. - d_ij : torch.Tensor - A 1D tensor containing the distances between each pair of atoms identified in `pair_indices`. Shape: [num_pairs, 1]. - r_ij : torch.Tensor - A 2D tensor of shape [num_pairs, 3], representing the displacement vectors between each pair of atoms. - number_of_atoms : int - A integer indicating the number of atoms in the batch. - positions : torch.Tensor - A 2D tensor of shape [num_atoms, 3], representing the XYZ coordinates of each atom within the system. - atomic_numbers : torch.Tensor - A 1D tensor containing atomic numbers for each atom, used to identify the type of each atom in the system(s). - atomic_subsystem_indices : torch.Tensor - A 1D tensor mapping each atom to its respective subsystem or molecule, useful for systems involving multiple - molecules or distinct subsystems. - total_charge : torch.Tensor - A tensor with the total charge of each system or molecule. Shape: [num_systems], where each entry corresponds - to a distinct system or molecule. - atomic_embedding : torch.Tensor - A 2D tensor containing embeddings or features for each atom, derived from atomic numbers. - Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors. - f_ij : Optional[torch.Tensor] - A tensor representing the radial symmetry function expansion of distances between atom pairs, capturing the - local chemical environment. Shape: [num_pairs, num_features], where `num_features` is the dimensionality of - the radial symmetry function expansion. This field will be populated after initialization. - f_cutoff : Optional[torch.Tensor] - A tensor representing the cosine cutoff function applied to the radial symmetry function expansion, ensuring - that atom pair contributions diminish smoothly to zero at the cutoff radius. Shape: [num_pairs]. This field - will be populated after initialization. - - Notes - ----- - The `SchnetNeuralNetworkInput` class is designed to encapsulate all necessary inputs for SchNet-based neural network - potentials in a structured and type-safe manner, facilitating efficient and accurate processing of input data by - the model. The inclusion of radial symmetry functions (`f_ij`) and cosine cutoff functions (`f_cutoff`) allows - for a detailed and nuanced representation of the atomistic systems, crucial for the accurate prediction of system - energies and properties. - - Examples - -------- - >>> inputs = SchnetNeuralNetworkInput( - ... pair_indices=torch.tensor([[0, 1], [0, 2], [1, 2]]), - ... d_ij=torch.tensor([1.0, 1.0, 1.0]), - ... r_ij=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), - ... number_of_atoms=3, - ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]), - ... atomic_numbers=torch.tensor([1, 6, 8]), - ... atomic_subsystem_indices=torch.tensor([0, 0, 0]), - ... total_charge=torch.tensor([0.0]), - ... atomic_embedding=torch.randn(3, 5), # Example atomic embeddings - ... f_ij=torch.randn(3, 4), # Example radial symmetry function expansion - ... f_cutoff=torch.tensor([0.5, 0.5, 0.5]) # Example cosine cutoff function - ... ) - """ - - atomic_embedding: torch.Tensor - f_ij: Optional[torch.Tensor] = field(default=None) - f_cutoff: Optional[torch.Tensor] = field(default=None) - - -class SchNetCore(CoreNetwork): + +from modelforge.utils.prop import NNPInput +from modelforge.potential.neighbors import PairlistData + + +class SchNetCore(torch.nn.Module): def __init__( self, - max_Z: int = 100, - number_of_atom_features: int = 64, - number_of_radial_basis_functions: int = 20, - number_of_interaction_modules: int = 3, - number_of_filters: int = 64, - shared_interactions: bool = False, - cutoff: unit.Quantity = 5.0 * unit.angstrom, + featurization: Dict[str, Dict[str, int]], + number_of_radial_basis_functions: int, + number_of_interaction_modules: int, + maximum_interaction_radius: float, + number_of_filters: int, + activation_function_parameter: Dict[str, str], + shared_interactions: bool, + predicted_properties: List[str], + predicted_dim: List[int], + potential_seed: int = -1, ) -> None: """ - Initialize the SchNet class. + Core SchNet architecture for modeling quantum interactions between atoms. Parameters ---------- - max_Z : int, default=100 - Maximum atomic number to be embedded. - number_of_atom_features : int, default=64 - Dimension of the embedding vectors for atomic numbers. - number_of_radial_basis_functions:int, default=16 - number_of_interaction_modules : int, default=2 - cutoff : openff.units.unit.Quantity, default=5*unit.angstrom - The cutoff distance for interactions. + featurization : Dict[str, Dict[str, int]] + Configuration for atom featurization, including number of features per atom. + number_of_radial_basis_functions : int + Number of radial basis functions for the SchNet representation. + number_of_interaction_modules : int + Number of interaction modules to use. + maximum_interaction_radius : float + Maximum distance for interactions. + number_of_filters : int + Number of filters for interaction layers. + activation_function_parameter : Dict[str, str] + Dictionary containing the activation function to use. + shared_interactions : bool + Whether to share weights across all interaction modules. + predicted_properties : List[str] + List of properties to predict. + predicted_dim : List[int] + List of dimensions for each predicted property. + potential_seed : int, optional + Seed for random number generation, by default -1. """ - from .utils import Dense, ShiftedSoftplus - log.debug("Initializing SchNet model.") - super().__init__() - self.number_of_atom_features = number_of_atom_features - self.number_of_filters = number_of_filters or self.number_of_atom_features - self.number_of_radial_basis_functions = number_of_radial_basis_functions + from modelforge.utils.misc import seed_random_number - # embedding - from modelforge.potential.utils import Embedding + if potential_seed != -1: + seed_random_number(potential_seed) - self.embedding_module = Embedding(max_Z, number_of_atom_features) + super().__init__() + self.activation_function = activation_function_parameter["activation_function"] - # Initialize representation block - self.schnet_representation_module = SchNETRepresentation( - cutoff, number_of_radial_basis_functions + log.debug("Initializing the SchNet architecture.") + from modelforge.potential.utils import DenseWithCustomDist + + # Set the number of filters and atom features + self.number_of_filters = number_of_filters or int( + featurization["atomic_number"]["number_of_per_atom_features"] ) - # Intialize interaction blocks - self.interaction_modules = nn.ModuleList( - [ - SchNETInteractionModule( - self.number_of_atom_features, - self.number_of_filters, - number_of_radial_basis_functions, - ) - for _ in range(number_of_interaction_modules) - ] + self.number_of_radial_basis_functions = number_of_radial_basis_functions + number_of_per_atom_features = int( + featurization["atomic_number"]["number_of_per_atom_features"] ) - # output layer to obtain per-atom energies - self.energy_layer = nn.Sequential( - Dense( - number_of_atom_features, - number_of_atom_features, - activation=ShiftedSoftplus(), - ), - Dense( - number_of_atom_features, - 1, - ), + # Initialize representation block for SchNet + self.schnet_representation_module = SchNETRepresentation( + maximum_interaction_radius, + number_of_radial_basis_functions, + featurization_config=featurization, ) + # Initialize interaction blocks, sharing or not based on config + if shared_interactions: + self.interaction_modules = nn.ModuleList( + [ + SchNETInteractionModule( + number_of_per_atom_features, + self.number_of_filters, + number_of_radial_basis_functions, + activation_function=self.activation_function, + ) + ] + * number_of_interaction_modules + ) - def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" - ) -> SchnetNeuralNetworkData: - number_of_atoms = data.atomic_numbers.shape[0] - - nnp_input = SchnetNeuralNetworkData( - pair_indices=pairlist_output.pair_indices, - d_ij=pairlist_output.d_ij, - r_ij=pairlist_output.r_ij, - number_of_atoms=number_of_atoms, - positions=data.positions, - atomic_numbers=data.atomic_numbers, - atomic_subsystem_indices=data.atomic_subsystem_indices, - total_charge=data.total_charge, - atomic_embedding=self.embedding_module( - data.atomic_numbers - ), # atom embedding - ) + else: + self.interaction_modules = nn.ModuleList( + [ + SchNETInteractionModule( + number_of_per_atom_features, + self.number_of_filters, + number_of_radial_basis_functions, + activation_function=self.activation_function, + ) + for _ in range(number_of_interaction_modules) + ] + ) - return nnp_input + # Initialize output layers based on predicted properties + self.output_layers = nn.ModuleDict() + for property, dim in zip(predicted_properties, predicted_dim): + self.output_layers[property] = nn.Sequential( + DenseWithCustomDist( + number_of_per_atom_features, + number_of_per_atom_features, + activation_function=self.activation_function, + ), + DenseWithCustomDist( + number_of_per_atom_features, + int(dim), + ), + ) def compute_properties( - self, data: SchnetNeuralNetworkData + self, data: NNPInput, pairlist_output: PairlistData ) -> Dict[str, torch.Tensor]: """ - Calculate the energy for a given input batch. + Compute properties based on the input data and pair list. Parameters ---------- - data : NamedTuple + data : NNPInput + Input data including atomic numbers, positions, etc. + pairlist_output: PairlistData + Output from the pairlist module, containing pair indices and + distances. Returns ------- Dict[str, torch.Tensor] - Calculated energies; shape (nr_systems,). + A dictionary containing the computed properties for each atom. """ + # Compute the atomic representation + representation = self.schnet_representation_module(data, pairlist_output) + atomic_embedding = representation["atomic_embedding"] + f_ij = representation["f_ij"] + f_cutoff = representation["f_cutoff"] - # Compute the representation for each atom (transform to radial basis set, multiply by cutoff) - representation = self.schnet_representation_module(data.d_ij) - data.f_ij = representation["f_ij"] - data.f_cutoff = representation["f_cutoff"] - - x = data.atomic_embedding - # Iterate over interaction blocks to update features + # Apply interaction modules to update the atomic embedding for interaction in self.interaction_modules: - v = interaction( - x, - data.pair_indices, - representation["f_ij"], - representation["f_cutoff"], + atomic_embedding = atomic_embedding + interaction( + atomic_embedding, + pairlist_output, + f_ij, + f_cutoff, ) - x = x + v # Update atomic features - - E_i = self.energy_layer(x).squeeze(1) return { - "per_atom_energy": E_i, - "scalar_representation": x, + "per_atom_scalar_representation": atomic_embedding, "atomic_subsystem_indices": data.atomic_subsystem_indices, + "atomic_numbers": data.atomic_numbers, } + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: + """ + Forward pass of the SchNet model. + + Parameters + ---------- + data : NNPInput + Input data including atomic numbers, positions, and relevant fields. + pairlist_output : PairlistData + Pair indices and distances from the pairlist module. + + Returns + ------- + Dict[str, torch.Tensor] + A dictionary of calculated properties from the forward pass. + """ + # Compute properties using the core method + results = self.compute_properties(data, pairlist_output) + atomic_embedding = results["per_atom_scalar_representation"] + + # Apply output layers to the atomic embedding + for output_name, output_layer in self.output_layers.items(): + results[output_name] = output_layer(atomic_embedding) + + return results + class SchNETInteractionModule(nn.Module): + def __init__( self, - number_of_atom_features: int, + number_of_per_atom_features: int, number_of_filters: int, number_of_radial_basis_functions: int, + activation_function: torch.nn.Module, ) -> None: """ - Initialize the SchNet interaction block. + SchNet interaction module to compute interaction terms based on atomic + distances and features. Parameters ---------- - number_of_atom_features : int - Number of atom ffeatures, defines the dimensionality of the embedding. + number_of_per_atom_features : int + Number of atom features, defines the dimensionality of the + embedding. number_of_filters : int - Number of filters, defines the dimensionality of the intermediate features. + Number of filters, defines the dimensionality of the intermediate + features. number_of_radial_basis_functions : int Number of radial basis functions. + activation_function : torch.nn.Module + The activation function to use in the interaction module. """ + super().__init__() - from .utils import Dense, ShiftedSoftplus + from .utils import DenseWithCustomDist assert ( number_of_radial_basis_functions > 4 ), "Number of radial basis functions must be larger than 10." assert number_of_filters > 1, "Number of filters must be larger than 1." assert ( - number_of_atom_features > 10 + number_of_per_atom_features > 10 ), "Number of atom basis must be larger than 10." - self.number_of_atom_features = number_of_atom_features # Initialize parameters - self.intput_to_feature = Dense( - number_of_atom_features, number_of_filters, bias=False, activation=None + self.number_of_per_atom_features = ( + number_of_per_atom_features # Initialize parameters + ) + + # Define input, filter, and output layers + self.intput_to_feature = DenseWithCustomDist( + number_of_per_atom_features, + number_of_filters, + bias=False, ) self.feature_to_output = nn.Sequential( - Dense( - number_of_filters, number_of_atom_features, activation=ShiftedSoftplus() + DenseWithCustomDist( + number_of_filters, + number_of_per_atom_features, + activation_function=activation_function, + ), + DenseWithCustomDist( + number_of_per_atom_features, + number_of_per_atom_features, ), - Dense(number_of_atom_features, number_of_atom_features, activation=None), ) self.filter_network = nn.Sequential( - Dense( + DenseWithCustomDist( number_of_radial_basis_functions, number_of_filters, - activation=ShiftedSoftplus(), + activation_function=activation_function, + ), + DenseWithCustomDist( + number_of_filters, + number_of_filters, ), - Dense(number_of_filters, number_of_filters, activation=None), ) def forward( self, - x: torch.Tensor, - pairlist: torch.Tensor, # shape [n_pairs, 2] - f_ij: torch.Tensor, # shape [n_pairs, number_of_radial_basis_functions] - f_ij_cutoff: torch.Tensor, # shape [n_pairs, 1] + atomic_embedding: torch.Tensor, + pairlist: PairlistData, + f_ij: torch.Tensor, + f_ij_cutoff: torch.Tensor, ) -> torch.Tensor: """ Forward pass for the interaction block. Parameters ---------- - x : torch.Tensor, shape [nr_of_atoms_in_systems, nr_atom_basis] + atomic_embedding : torch.Tensor Input feature tensor for atoms (output of embedding). - pairlist : torch.Tensor, shape [n_pairs, 2] - f_ij : torch.Tensor, shape [n_pairs, 1, number_of_radial_basis_functions] + pairlist : PairlistData + List of atom pairs. + f_ij : torch.Tensor, shape [n_pairs, number_of_radial_basis_functions] Radial basis functions for pairs of atoms. f_ij_cutoff : torch.Tensor, shape [n_pairs, 1] + Cutoff values for the pairs. Returns ------- torch.Tensor, shape [nr_of_atoms_in_systems, nr_atom_basis] Updated feature tensor after interaction block. """ - idx_i, idx_j = pairlist[0], pairlist[1] + idx_i, idx_j = pairlist.pair_indices[0], pairlist.pair_indices[1] - # Map input features to the filter space - x = self.intput_to_feature(x) + # Transform atomic embedding to filter space + atomic_embedding = self.intput_to_feature(atomic_embedding) # Generate interaction filters based on radial basis functions - W_ij = self.filter_network(f_ij.squeeze(1)) # FIXME - W_ij = W_ij * f_ij_cutoff + W_ij = self.filter_network(f_ij.squeeze(1)) + W_ij = W_ij * f_ij_cutoff # Shape: [n_pairs, number_of_filters] # Perform continuous-filter convolution - x_j = x[idx_j] - x_ij = x_j * W_ij # (nr_of_atom_pairs, nr_atom_basis) - out = torch.zeros_like(x) - out.scatter_add_(0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij) # from per_atom_pair to _per_atom + x_j = atomic_embedding[idx_j] + x_ij = x_j * W_ij # Element-wise multiplication + + out = torch.zeros_like(atomic_embedding).scatter_add_( + 0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij + ) # Aggregate per-atom pair to per-atom - return self.feature_to_output(out) # shape: (nr_of_atoms, 1) + return self.feature_to_output(out) # Output the updated atomic features class SchNETRepresentation(nn.Module): + def __init__( self, - radial_cutoff: unit.Quantity, + radial_cutoff: float, number_of_radial_basis_functions: int, + featurization_config: Dict[str, Dict[str, int]], ): """ - Initialize the SchNet representation layer. + SchNet representation module to generate the radial symmetry + representation of pairwise distances. Parameters ---------- - Radial Basis Function Module + radial_cutoff : float + The cutoff distance for interactions in nanometer. + number_of_radial_basis_functions : int + Number of radial basis functions. + featurization_config : Dict[str, Dict[str, int]] + Configuration for atom featurization. """ super().__init__() self.radial_symmetry_function_module = self._setup_radial_symmetry_functions( radial_cutoff, number_of_radial_basis_functions ) - # cutoff - from modelforge.potential import CosineCutoff + # Initialize cutoff module + from modelforge.potential import CosineAttenuationFunction, FeaturizeInput - self.cutoff_module = CosineCutoff(radial_cutoff) + self.featurize_input = FeaturizeInput(featurization_config) + self.cutoff_module = CosineAttenuationFunction(radial_cutoff) def _setup_radial_symmetry_functions( - self, radial_cutoff: unit.Quantity, number_of_radial_basis_functions: int + self, radial_cutoff: float, number_of_radial_basis_functions: int ): - from .utils import SchnetRadialBasisFunction + from modelforge.potential import SchnetRadialBasisFunction radial_symmetry_function = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, @@ -334,103 +351,34 @@ def _setup_radial_symmetry_functions( ) return radial_symmetry_function - def forward(self, d_ij: torch.Tensor) -> Dict[str, torch.Tensor]: + def forward( + self, data: NNPInput, pairlist_output: PairlistData + ) -> Dict[str, torch.Tensor]: """ - Generate the radial symmetry representation of the pairwise distances. + Forward pass to generate the radial symmetry representation of pairwise + distances. Parameters ---------- - d_ij : Pairwise distances between atoms; shape [n_pairs, 1] + data : NNPInput + Input data containing atomic numbers and positions. + pairlist_output : PairlistData + Output from the pairlist module, containing pair indices and distances. Returns ------- - Radial basis functions for pairs of atoms; shape [n_pairs, 1, number_of_radial_basis_functions] + Dict[str, torch.Tensor] + A dictionary containing radial basis functions, cutoff values, and atomic embeddings. """ # Convert distances to radial basis functions - f_ij = self.radial_symmetry_function_module( - d_ij - ) # shape (n_pairs, number_of_radial_basis_functions) - - f_cutoff = self.cutoff_module(d_ij) # shape (n_pairs, 1) - - return {"f_ij": f_ij, "f_cutoff": f_cutoff} - - -from typing import List, Union - - -class SchNet(BaseNetwork): - def __init__( - self, - max_Z: int, - number_of_atom_features: int, - number_of_radial_basis_functions: int, - number_of_interaction_modules: int, - cutoff: Union[unit.Quantity, str], - number_of_filters: int, - shared_interactions: bool, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic: Optional[Dict[str, float]] = None, - ) -> None: - """ - Initialize the SchNet network. + f_ij = self.radial_symmetry_function_module(pairlist_output.d_ij) - Schütt, Kindermans, Sauceda, Chmiela, Tkatchenko, Müller: - SchNet: A continuous-filter convolutional neural network for modeling quantum - interactions. + # Apply cutoff function to distances + f_cutoff = self.cutoff_module(pairlist_output.d_ij) # shape (n_pairs, 1) - Parameters - ---------- - max_Z : int, default=100 - Maximum atomic number to be embedded. - number_of_atom_features : int, default=64 - Dimension of the embedding vectors for atomic numbers. - number_of_radial_basis_functions:int, default=16 - number_of_interaction_modules : int, default=2 - cutoff : openff.units.unit.Quantity, default=5*unit.angstrom - The cutoff distance for interactions. - """ - from modelforge.utils.units import _convert - - self.only_unique_pairs = False # NOTE: need to be set before super().__init__ - - super().__init__( - dataset_statistic=dataset_statistic, - postprocessing_parameter=postprocessing_parameter, - cutoff=_convert(cutoff), - ) - - self.core_module = SchNetCore( - max_Z=max_Z, - number_of_atom_features=number_of_atom_features, - number_of_radial_basis_functions=number_of_radial_basis_functions, - number_of_interaction_modules=number_of_interaction_modules, - number_of_filters=number_of_filters, - shared_interactions=shared_interactions, - ) - - def _config_prior(self): - log.info("Configuring SchNet model hyperparameter prior distribution") - from modelforge.utils.io import import_ - - tune = import_("ray").tune - # from ray import tune - - from modelforge.potential.utils import shared_config_prior - - prior = { - "number_of_atom_features": tune.randint(2, 256), - "number_of_interaction_modules": tune.randint(1, 5), - "cutoff": tune.uniform(5, 10), - "number_of_radial_basis_functions": tune.randint(8, 32), - "number_of_filters": tune.randint(32, 128), - "shared_interactions": tune.choice([True, False]), + return { + "f_ij": f_ij, + "f_cutoff": f_cutoff, + "atomic_embedding": self.featurize_input(data), } - prior.update(shared_config_prior()) - return prior - - def combine_per_atom_properties( - self, values: Dict[str, torch.Tensor] - ) -> torch.Tensor: - return values diff --git a/modelforge/potential/tensornet.py b/modelforge/potential/tensornet.py index 5b3711ca..9f275a61 100644 --- a/modelforge/potential/tensornet.py +++ b/modelforge/potential/tensornet.py @@ -1,106 +1,827 @@ -from dataclasses import dataclass +""" +TensorNet network for molecular potential learning. +""" + +from typing import Dict, List, Tuple import torch -from openff.units import unit +from torch import nn + +from modelforge.potential import CosineAttenuationFunction, TensorNetRadialBasisFunction -from modelforge.potential.models import InputPreparation -from modelforge.potential.models import BaseNetwork -from modelforge.potential.utils import NeuralNetworkData +from modelforge.utils.prop import NNPInput +from modelforge.potential.neighbors import PairlistData -class TensorNet(BaseNetwork): +class DenseAndSum(nn.Module): def __init__( self, - radial_max_distance: unit.Quantity = 5.1 * unit.angstrom, - radial_min_distanc: unit.Quantity = 0.0 * unit.angstrom, - number_of_radial_basis_functions: int = 16, + input_dim: int, + output_dim: int, + sum_dim: int, + ): + """ + A dense (fully connected) layer followed by a summation over a specified dimension. + + Parameters + ---------- + input_dim : int + Input dimensionality of the dense layer. + output_dim : int + Output dimensionality of the dense layer. + sum_dim : int + Dimension over which to sum the result after applying the dense layer. + """ + super().__init__() + self.dense = nn.Linear(input_dim, output_dim) + self.sum_dim = sum_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the DenseAndSum layer. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor after applying the dense layer and summing over the specified dimension. + """ + x = self.dense(x) + return x.sum(dim=self.sum_dim) + + +def vector_to_skewtensor(r_ij_norm: torch.Tensor) -> torch.Tensor: + """ + Creates a skew-symmetric tensor (A) from a vector + (equation 3 in TensorNet paper). + + Parameters + ---------- + r_ij_norm : torch.Tensor + Normalized displacement vectors of given atom pairs. + + Returns + ------- + torch.Tensor + Matrix A from equation 3 in TensorNet paper. + """ + + zero = torch.zeros_like(r_ij_norm[:, 0]) + out = torch.stack( + ( + zero, + -r_ij_norm[:, 2], + r_ij_norm[:, 1], + r_ij_norm[:, 2], + zero, + -r_ij_norm[:, 0], + -r_ij_norm[:, 1], + r_ij_norm[:, 0], + zero, + ), + dim=-1, + ).view(-1, 3, 3) + + return out + + +def vector_to_symtensor(r_ij_norm: torch.Tensor) -> torch.Tensor: + """ + Creates a symmetric traceless tensor (S) from the outer product of a vector + with itself (equation 3 in TensorNet paper). + + Parameters + ---------- + r_ij_norm : torch.Tensor + Normalized displacement vectors of given atom pairs. + + Returns + ------- + torch.Tensor + Matrix S from equation 3 in TensorNet paper. + """ + + r_ij_norm = r_ij_norm.unsqueeze(-1) * r_ij_norm.unsqueeze(-2) + I = torch.eye(3, device=r_ij_norm.device, dtype=r_ij_norm.dtype) * ( + r_ij_norm.diagonal(dim1=-2, dim2=-1).mean(-1, keepdim=True) + ).unsqueeze(-1) + S = 0.5 * (r_ij_norm + r_ij_norm.transpose(-1, -2)) - I + return S + + +def decompose_tensor( + tensor: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Decomposes a tensor into its irreducible components (I, A, S) (Equation 2 and 3 in TensorNet paper). + + Parameters + ---------- + tensor : torch.Tensor + Input tensor representing pair-wise features of the atomic system, shape (n_atoms, 3, 3). + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + Decomposed components: Identity tensor I, skew-symmetric tensor A, and symmetric traceless tensor S. + """ + + diag_mean = tensor.diagonal(offset=0, dim1=-1, dim2=-2).mean(-1) + I = diag_mean[..., None, None] * torch.eye( + 3, 3, device=tensor.device, dtype=tensor.dtype + ) + A = 0.5 * (tensor - tensor.transpose(-2, -1)) + S = 0.5 * (tensor + tensor.transpose(-2, -1)) - I + + return I, A, S + + +def tensor_norm(tensor: torch.Tensor) -> torch.Tensor: + """ + Compute Frobenius norm + (mentioned at the end of section 3.1 in TensorNet paper). + + Parameters + ---------- + tensor : torch.Tensor + Input tensor, shape (n_atoms, 3, 3). + + Returns + ------- + torch.Tensor + Frobenius norm of the input tensor. + """ + # Note: the Frobenius norm is actually the square root of the sum of squares, so assert torch.allclose(torch.norm(tensor, p="fro", dim=(-2, -1))**2, (tensor**2).sum((-2, -1)) == True + return (tensor**2).sum((-2, -1)) + + +def tensor_message_passing( + pair_indices: torch.Tensor, + radial_feature_vector: torch.Tensor, + tensor: torch.Tensor, + tensor_shape: Tuple[int, int, int, int], +) -> torch.Tensor: + """ + Helper function to calculate message passing tensor M + ("Interaction and node update", section 3.2 in TensorNet paper). + Tensor I, A, and S are parsed separately into this helper function. + + Parameters + ---------- + pair_indices : torch.Tensor + A pair-wise index tensor specifying the corresponding atomic pairs. + radial_feature_vector : torch.Tensor + Radial feature vector calculated through TensorNetRadialBasisFunction. + tensor : torch.Tensor + A pair-wise feature tensor decomposed term (I, A, or S). + number_of_atoms : int + Number of atoms in the system. + + Returns + ------- + torch.Tensor + A Message tensor calculated from I, A, or S. + """ + + # Compute the message for each pair + msg = radial_feature_vector * tensor.index_select(0, pair_indices[1]) + # Pre-allocate tensor for the aggregated messages + tensor_m = torch.zeros(tensor_shape, device=tensor.device, dtype=tensor.dtype) + # Aggregate the messages, using in-place addition to avoid unnecessary + # copies + tensor_m.index_add_(0, pair_indices[0], msg) + return tensor_m + + +class TensorNetCore(torch.nn.Module): + def __init__( + self, + number_of_per_atom_features: int, + number_of_interaction_layers: int, + number_of_radial_basis_functions: int, + maximum_interaction_radius: float, + minimum_interaction_radius: float, + maximum_atomic_number: int, + equivariance_invariance_group: str, + activation_function_parameter: Dict[str, str], + predicted_properties: List[str], + predicted_dim: List[int], + potential_seed: int = -1, + trainable_centers_and_scale_factors: bool = False, ) -> None: + """ + Core TensorNet model for molecular potential learning. + Parameters + ---------- + number_of_per_atom_features : int + Number of features per atom. + number_of_interaction_layers : int + Number of interaction layers in the network. + number_of_radial_basis_functions : int + Number of radial basis functions. + maximum_interaction_radius : float + Maximum interaction radius for atomic interactions. + minimum_interaction_radius : float + Minimum interaction radius for atomic interactions. + maximum_atomic_number : int + Maximum atomic number allowed for the model. + equivariance_invariance_group : str + Specifies the equivariance invariance group ("O(3)" or "SO(3)"). + activation_function_parameter : Dict[str, str] + Activation function configuration. + predicted_properties : List[str] + List of properties to predict. + predicted_dim : List[int] + List of output dimensions for each predicted property. + potential_seed : int, optional + Random seed for reproducibility. Default is -1. + trainable_centers_and_scale_factors : bool, optional + Whether the centers and scale factors for the radial basis functions are trainable. Default is False. + """ super().__init__() + activation_function = activation_function_parameter["activation_function"] + + from modelforge.utils.misc import seed_random_number - self.core_module = TensorNetCore( - radial_max_distance, - radial_min_distanc, - number_of_radial_basis_functions, + if potential_seed != -1: + seed_random_number(potential_seed) + + self.representation_module = TensorNetRepresentation( + number_of_per_atom_features=number_of_per_atom_features, + number_of_radial_basis_functions=number_of_radial_basis_functions, + activation_function=activation_function, + maximum_interaction_radius=maximum_interaction_radius, + minimum_interaction_radius=minimum_interaction_radius, + trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, + maximum_atomic_number=maximum_atomic_number, ) - self.only_unique_pairs = True # NOTE: for pairlist - self.input_preparation = InputPreparation( - cutoff=radial_max_distance, only_unique_pairs=self.only_unique_pairs + self.interaction_modules = nn.ModuleList( + [ + TensorNetInteraction( + number_of_per_atom_features=number_of_per_atom_features, + number_of_radial_basis_functions=number_of_radial_basis_functions, + activation_function=activation_function, + maximum_interaction_radius=maximum_interaction_radius, + equivariance_invariance_group=equivariance_invariance_group, + ) + for _ in range(number_of_interaction_layers) + ] ) + # Initialize output layers based on configuration + self.output_layers = nn.ModuleDict() + for property, dim in zip(predicted_properties, predicted_dim): + self.output_layers[property] = DenseAndSum( + 3 * number_of_per_atom_features, + number_of_per_atom_features, + dim, + ) + + self.perform_layer_normalization = nn.LayerNorm(3 * number_of_per_atom_features) + + def compute_properties( + self, + data: NNPInput, + pairlist_output: PairlistData, + ) -> Dict[str, torch.Tensor]: + """ + Compute properties for the TensorNet model. + + Parameters + ---------- + data : NNPInput + The input data for the model. + pairlist_output : PairlistData + The pair list data including distances and indices of atom pairs. + + Returns + ------- + Dict[str, torch.Tensor] + The calculated properties, including atomic subsystem indices and + atomic numbers. + """ + + # generate initial embedding + X, radial_feature_vector = self.representation_module(data, pairlist_output) + + # using interlevae and bincount to generate a total charge per molecule + expanded_total_charge = torch.repeat_interleave( + data.per_system_total_charge, data.atomic_subsystem_indices.bincount() + ) + + for layer in self.interaction_modules: + X = layer( + X, + pairlist_output.pair_indices, + pairlist_output.d_ij.squeeze(-1), + radial_feature_vector.squeeze(1), + expanded_total_charge, + ) + + I, A, S = decompose_tensor(X) + + per_atom_scalar_representation = torch.cat( + (tensor_norm(I), tensor_norm(A), tensor_norm(S)), + dim=-1, + ) + + per_atom_scalar_representation = self.perform_layer_normalization( + per_atom_scalar_representation + ) + + return { + "per_atom_scalar_representation": per_atom_scalar_representation, + "atomic_subsystem_indices": data.atomic_subsystem_indices, + "atomic_numbers": data.atomic_numbers, + } + + def forward( + self, + data: NNPInput, + pairlist_output: PairlistData, + ) -> Dict[str, torch.Tensor]: + """ + Forward pass through the TensorNet model. + + Parameters + ---------- + data : NNPInput + Input data including atomic numbers and positions. + pairlist_output : PairlistData + Pair list output with distances and displacement vectors. + + Returns + ------- + Dict[str, torch.Tensor] + Calculated per-atom properties from the forward pass. + """ + # perform the forward pass implemented in the subclass + results = self.compute_properties(data, pairlist_output) + # extract the atomic embedding + atomic_embedding = results["per_atom_scalar_representation"] + # Compute all specified outputs + for output_name, output_layer in self.output_layers.items(): + results[output_name] = output_layer(atomic_embedding).unsqueeze(1) + + return results + + +class TensorNetRepresentation(torch.nn.Module): -class TensorNetCore(torch.nn.Module): def __init__( self, - radial_max_distance: unit.Quantity, - radial_min_distanc: unit.Quantity, + number_of_per_atom_features: int, number_of_radial_basis_functions: int, + activation_function: nn.Module, + maximum_interaction_radius: float, + minimum_interaction_radius: float, + trainable_centers_and_scale_factors: bool, + maximum_atomic_number: int, ): + """ + TensorNet representation module for molecular systems. + + Parameters + ---------- + number_of_per_atom_features : int + Number of features per atom. + number_of_radial_basis_functions : int + Number of radial basis functions. + activation_function : nn.Module + Activation function class. + maximum_interaction_radius : float + Maximum interaction radius in nanometer. + minimum_interaction_radius : float + Minimum interaction radius in nanometer. + trainable_centers_and_scale_factors : bool + If True, centers and scale factors are trainable. + maximum_atomic_number : int + Maximum atomic number in the dataset. + """ super().__init__() + from modelforge.potential.utils import Dense + + self.number_of_per_atom_features = number_of_per_atom_features - # Initialize representation block - self.tensornet_representation_module = TensorNetRepresentation( - radial_max_distance, - radial_min_distanc, - number_of_radial_basis_functions, + self.cutoff_module = CosineAttenuationFunction(maximum_interaction_radius) + + self.radial_symmetry_function = TensorNetRadialBasisFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=maximum_interaction_radius, + min_distance=minimum_interaction_radius, + alpha=( + (maximum_interaction_radius - minimum_interaction_radius) / 5.0 + ), # TensorNet uses angstrom + trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, + ) + self.rsf_projections = nn.ModuleDict( + { + "I": nn.Linear( + number_of_radial_basis_functions, number_of_per_atom_features + ), + "A": nn.Linear( + number_of_radial_basis_functions, number_of_per_atom_features + ), + "S": nn.Linear( + number_of_radial_basis_functions, number_of_per_atom_features + ), + } + ) + self.atomic_number_i_embedding_layer = nn.Embedding( + maximum_atomic_number, + number_of_per_atom_features, + ) + self.atomic_number_ij_embedding_layer = nn.Linear( + 2 * number_of_per_atom_features, + number_of_per_atom_features, + ) + self.activation_function = activation_function + # initialize linear layer for I, A and S + self.linears_tensor = nn.ModuleList( + [ + nn.Linear( + number_of_per_atom_features, number_of_per_atom_features, bias=False + ) + for _ in range(3) + ] + ) + self.linears_scalar = nn.Sequential( + *[ + Dense( + number_of_per_atom_features, + 2 * number_of_per_atom_features, + bias=True, + activation_function=self.activation_function, + ), + Dense( + 2 * number_of_per_atom_features, + 3 * number_of_per_atom_features, + bias=True, + activation_function=self.activation_function, + ), + ] ) + self.batch_layer_normalization = nn.LayerNorm(number_of_per_atom_features) + self.reset_parameters() - self.interaction_modules = ANIInteraction() + def reset_parameters(self): + """ + Initialize neural network parameters of the representation layer. + """ + self.rsf_projections["I"].reset_parameters() + self.rsf_projections["A"].reset_parameters() + self.rsf_projections["S"].reset_parameters() + self.atomic_number_i_embedding_layer.reset_parameters() + self.atomic_number_ij_embedding_layer.reset_parameters() + for linear in self.linears_tensor: + linear.reset_parameters() + for linear in self.linears_scalar: + linear.reset_parameters() + self.batch_layer_normalization.reset_parameters() - def _forward(self): - pass + def _get_atomic_number_message( + self, + atomic_number: torch.Tensor, + pair_indices: torch.Tensor, + ) -> torch.Tensor: + """ + Get the atomic number embedding for each atom pair. - def _model_specific_input_preparation(self): - pass + (mentioned in equation 8 in TensorNet paper, not explicitly defined). + This embedding consists of two steps: + 1. embed atom type of each atom into a vector + 2. the embedding of an atom pair is the linear combination of the + embedding vector of these two atoms in the atom pair - def forward(self): - pass + Parameters + ---------- + atomic_number : torch.Tensor + A tensor includes atomic numbers for every atom in the system. + pair_indices : torch.Tensor + A pair-wise index tensor specifying the corresponding atomic pairs. + Returns + ------- + torch.Tensor + The embedding tensor for atomic numbers of atom pairs. + """ + atomic_number_i_embedding = self.atomic_number_i_embedding_layer(atomic_number) + pair_indices_flat = pair_indices.t().reshape(-1) -class TensorNetRepresentation(torch.nn.Module): + atomic_number_ij_embedding = self.atomic_number_ij_embedding_layer( + atomic_number_i_embedding[pair_indices_flat].view( + -1, self.number_of_per_atom_features * 2 + ) + )[..., None, None] + return atomic_number_ij_embedding + + def _get_tensor_messages( + self, + atomic_number_embedding: torch.Tensor, + d_ij: torch.Tensor, + r_ij_norm: torch.Tensor, + radial_feature_vector: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generate I, A, and S tensor messages for atom pairs. + (equation 8 in TensorNet paper). + + Parameters + ---------- + atomic_number_embedding : torch.Tensor + The embedding tensor for atomic numbers of atom pairs. + d_ij : torch.Tensor + Atomic pair-wise distances. + r_ij_norm : torch.Tensor + normalized displacement vectors, by dividing r_ij by d_ij + radial_feature_vector : torch.Tensor + Radial feature vector calculated through + TensorNetRadialBasisFunction. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + The Iij, Aij, Sij terms in equation 8, before adding up since these + three terms are treated separately. + """ + + C = self.cutoff_module(d_ij).reshape(-1, 1, 1, 1) * atomic_number_embedding + eye = torch.eye(3, 3, device=r_ij_norm.device, dtype=r_ij_norm.dtype)[ + None, None, ... + ] + Iij = ( + self.rsf_projections["I"](radial_feature_vector).permute(0, 2, 1)[..., None] + * C + * eye + ) + Aij = ( + self.rsf_projections["A"](radial_feature_vector).permute(0, 2, 1)[..., None] + * C + * vector_to_skewtensor(r_ij_norm)[..., None, :, :] + ) + Sij = ( + self.rsf_projections["S"](radial_feature_vector).permute(0, 2, 1)[..., None] + * C + * vector_to_symtensor(r_ij_norm)[..., None, :, :] + ) + return Iij, Aij, Sij + + def forward( + self, + data: NNPInput, + pairlist_output: PairlistData, + ): + """ + Forward pass for the representation module. + (equation 10 in TensorNet paper). + + Parameters + ---------- + data : NNPInput + Input data for the system, including atomic numbers and positions. + pairlist_output : PairlistData + Output from the pair list module, including pair indices and distances. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + The first value is the X tensor as a representation of the system. + The second value is the radial feature vector that is required + by compute_properties of TensorNetCore. + """ + atomic_number_embedding = self._get_atomic_number_message( + data.atomic_numbers, + pairlist_output.pair_indices, + ) + r_ij_norm = pairlist_output.r_ij / pairlist_output.d_ij + + radial_feature_vector = self.radial_symmetry_function( + pairlist_output.d_ij + ) # in nanometer + rcut_ij = self.cutoff_module( + pairlist_output.d_ij + ) # cutoff function applied twice + radial_feature_vector = torch.mul(radial_feature_vector, rcut_ij).unsqueeze(1) + + Iij, Aij, Sij = self._get_tensor_messages( + atomic_number_embedding, + pairlist_output.d_ij, + r_ij_norm, + radial_feature_vector, + ) + source = torch.zeros( + data.atomic_numbers.shape[0], + self.number_of_per_atom_features, + 3, + 3, + device=data.atomic_numbers.device, + dtype=Iij.dtype, + ) + I = source.index_add(dim=0, index=pairlist_output.pair_indices[0], source=Iij) + A = source.index_add(dim=0, index=pairlist_output.pair_indices[0], source=Aij) + S = source.index_add(dim=0, index=pairlist_output.pair_indices[0], source=Sij) + + # equation 9 in TensorNet paper + # batch normalization + # NOTE: call init_norm differently + nomalized_tensor_I_A_S = self.batch_layer_normalization(tensor_norm(I + A + S)) + + nomalized_tensor_I_A_S = self.linears_scalar(nomalized_tensor_I_A_S).reshape( + -1, self.number_of_per_atom_features, 3 + ) + + # now equation 10 + # apply linear layers to I, A, S and return + I = ( + self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * nomalized_tensor_I_A_S[..., 0, None, None] + ) + A = ( + self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * nomalized_tensor_I_A_S[..., 1, None, None] + ) + S = ( + self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + * nomalized_tensor_I_A_S[..., 2, None, None] + ) + X = I + A + S + return X, radial_feature_vector + + +class TensorNetInteraction(torch.nn.Module): def __init__( self, - radial_max_distance, - radial_min_distanc, - number_of_radial_basis_functions, + number_of_per_atom_features: int, + number_of_radial_basis_functions: int, + activation_function: nn.Module, + maximum_interaction_radius: float, + equivariance_invariance_group: str, ): - pass + """ + TensorNet interaction module for message passing and updating atomic features. + Parameters + ---------- + number_of_per_atom_features : int + Number of features per atom. + number_of_radial_basis_functions : int + Number of radial basis functions. + activation_function : nn.Module + Activation function class. + maximum_interaction_radius : float + Maximum interaction radius in nanometer. + equivariance_invariance_group : str + Equivariance invariance group, either "O(3)" or "SO(3)". + """ -class ANIInteraction(torch.nn.Module): - def __init__(self): - pass + super().__init__() + from modelforge.potential.utils import Dense - def forward(self): - pass + self.number_of_per_atom_features = number_of_per_atom_features + self.number_of_radial_basis_functions = number_of_radial_basis_functions + self.activation_function = activation_function + self.cutoff_module = CosineAttenuationFunction(maximum_interaction_radius) + self.mlp_scalar = nn.Sequential( + Dense( + number_of_radial_basis_functions, + number_of_per_atom_features, + bias=True, + activation_function=self.activation_function, + ), + Dense( + number_of_per_atom_features, + 2 * number_of_per_atom_features, + bias=True, + activation_function=self.activation_function, + ), + Dense( + 2 * number_of_per_atom_features, + 3 * number_of_per_atom_features, + bias=True, + activation_function=self.activation_function, + ), + ) + self.linear_layer = nn.Sequential( + *[ + Dense( + number_of_per_atom_features, number_of_per_atom_features, bias=False + ) + for _ in range(6) + ] + ) + self.equivariance_invariance_group = equivariance_invariance_group + self.reset_parameters() -@dataclass -class TensorNetNeuralNetworkData(NeuralNetworkData): - """ - A dataclass to structure the inputs for ANI neural network potentials, designed to - facilitate the efficient representation of atomic systems for energy computation and - property prediction. + def reset_parameters(self): + """ + Initialize neural network parameters of the interaction layer. + """ + for linear in self.mlp_scalar: + try: + linear.reset_parameters() + except AttributeError: + pass + for linear in self.linear_layer: + try: + linear.reset_parameters() + except AttributeError: + pass - Attributes - ---------- - pair_indices : torch.Tensor - A 2D tensor indicating the indices of atom pairs. Shape: [2, num_pairs]. - d_ij : torch.Tensor - A 1D tensor containing distances between each pair of atoms. Shape: [num_pairs, 1]. - r_ij : torch.Tensor - A 2D tensor representing displacement vectors between atom pairs. Shape: [num_pairs, 3]. - number_of_atoms : int - An integer indicating the number of atoms in the batch. - positions : torch.Tensor - A 2D tensor representing the XYZ coordinates of each atom. Shape: [num_atoms, 3]. - atom_index : torch.Tensor - A 1D tensor containing atomic numbers for each atom in the system(s). Shape: [num_atoms]. - atomic_subsystem_indices : torch.Tensor - A 1D tensor mapping each atom to its respective subsystem or molecule. Shape: [num_atoms]. - total_charge : torch.Tensor - An tensor with the total charge of each system or molecule. Shape: [num_systems]. - atomic_numbers : torch.Tensor - A 1D tensor containing the atomic numbers for atoms, used for identifying the atom types within the model. Shape: [num_atoms]. + def forward( + self, + X: torch.Tensor, + pair_indices: torch.Tensor, + d_ij: torch.Tensor, + radial_feature_vector: torch.Tensor, + atomic_charges: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the output of the interaction layer and update tensor X. Updates + the tensor through message passing, scalar transformation, and tensor + decomposition ("Interaction and node update" in from section 3.2 in + TensorNet paper). X^(i) <- X^(i) + Delta X^(i) - """ + Parameters + ---------- + X : torch.Tensor + X tensor specifies pair-wise features of the atomic system. + pair_indices : torch.Tensor + A pair-wise index tensor specifying the corresponding atomic pairs. + d_ij : torch.Tensor + Atomic pair-wise distances. + radial_feature_vector : torch.Tensor + Radial feature vector calculated through + TensorNetRadialBasisFunction. + atomic_charges: torch.Tensor + Total charge q is a molecule-wise property. We transform it into an + atom-wise property, with all atoms belonging to the same molecule + being asqsigned the same charge q + (https://github.com/torchmd/torchmd-net/blob/6dea4b61e24de3e18921397866b7d9c5fd6b8bf1/torchmdnet/models/tensornet.py#L237) + + Returns + ------- + torch.Tensor + The updated X tensor. + """ + + # see equation 11 + C = self.cutoff_module(d_ij).view(-1, 1) + + # apply scalar MLP to radial feature vector and combine with cutoff + radial_feature_vector = self.mlp_scalar(radial_feature_vector) * C + + radial_feature_vector = radial_feature_vector.view( + radial_feature_vector.shape[0], self.number_of_per_atom_features, 3 + ) + X_shape = X.shape + feature_shape = (X_shape[0], X_shape[1], X_shape[2], X_shape[3]) + + X = X / (tensor_norm(X) + 1)[..., None, None] + I, A, S = decompose_tensor(X) + I = self.linear_layer[0](I.transpose(1, 3)).transpose(1, 3) + A = self.linear_layer[1](A.transpose(1, 3)).transpose(1, 3) + S = self.linear_layer[2](S.transpose(1, 3)).transpose(1, 3) + + Y = I + A + S + + Im = tensor_message_passing( + pair_indices, radial_feature_vector[..., 0, None, None], I, feature_shape + ) + Am = tensor_message_passing( + pair_indices, radial_feature_vector[..., 1, None, None], A, feature_shape + ) + Sm = tensor_message_passing( + pair_indices, radial_feature_vector[..., 2, None, None], S, feature_shape + ) + msg = Im + Am + Sm + + if self.equivariance_invariance_group == "O(3)": + A = torch.matmul(msg, Y) + B = torch.matmul(Y, msg) + I, A, S = decompose_tensor( + (1 + 0.1 * atomic_charges[..., None, None, None]) * (A + B) + ) + + if self.equivariance_invariance_group == "SO(3)": + B = torch.matmul(Y, msg) + I, A, S = decompose_tensor(2 * B) + + normp1 = (tensor_norm(I + A + S) + 1)[..., None, None] + I, A, S = I / normp1, A / normp1, S / normp1 + I = self.linear_layer[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + A = self.linear_layer[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + S = self.linear_layer[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + dX = I + A + S + X = ( + X + + dX + + (1 + 0.1 * atomic_charges[..., None, None, None]) + * torch.matrix_power(dX, 2) + ) + return X diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index b40c5044..c1cc665a 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -1,30 +1,16 @@ +""" +Utility functions for neural network potentials. +""" + import math -from dataclasses import dataclass, field -from typing import Any, Callable, Optional, Tuple, NamedTuple, Type +from dataclasses import dataclass +from typing import Callable, Optional -import numpy as np import torch import torch.nn as nn -from loguru import logger as log from openff.units import unit -from pint import Quantity -from typing import Union -from modelforge.dataset.dataset import NNPInput - -@dataclass -class NeuralNetworkData: - pair_indices: torch.Tensor - d_ij: torch.Tensor - r_ij: torch.Tensor - atomic_numbers: torch.Tensor - number_of_atoms: int - positions: torch.Tensor - atomic_subsystem_indices: torch.Tensor - total_charge: torch.Tensor - - -import torch +from modelforge.utils.prop import NNPInput @dataclass(frozen=False) @@ -75,10 +61,8 @@ def to( def shared_config_prior(): - from modelforge.utils.io import import_ - tune = import_("ray").tune - # from ray import tune + from ray import tune return { "lr": tune.loguniform(1e-5, 1e-1), @@ -87,123 +71,66 @@ def shared_config_prior(): } -def triple_by_molecule( - atom_pairs: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Input: indices for pairs of atoms that are close to each other. - each pair only appear once, i.e. only one of the pairs (1, 2) and - (2, 1) exists. - - NOTE: this function is taken from https://github.com/aiqm/torchani/blob/17204c6dccf6210753bc8c0ca4c92278b60719c9/torchani/aev.py - with little modifications. - """ - - def cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor: - cumsum = torch.zeros_like(input_) - torch.cumsum(input_[:-1], dim=0, out=cumsum[1:]) - return cumsum - - # convert representation from pair to central-others - ai1 = atom_pairs.view(-1) - - # Note, torch.sort doesn't guarantee stable sort by default. - # This means that the order of rev_indices is not guaranteed when there are "ties" - # (i.e., identical values in the input tensor). - # Stable sort is more expensive and ultimately unnecessary, so we will not use it here, - # but it does mean that vector-wise comparison of the outputs of this function may be - # inconsistent for the same input, and thus tests must be designed accordingly. - - sorted_ai1, rev_indices = ai1.sort() - - # sort and compute unique key - uniqued_central_atom_index, counts = torch.unique_consecutive( - sorted_ai1, return_inverse=False, return_counts=True - ) - - # compute central_atom_index - pair_sizes = torch.div(counts * (counts - 1), 2, rounding_mode="trunc") - pair_indices = torch.repeat_interleave(pair_sizes) - central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices) +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, zeros_ - # do local combinations within unique key, assuming sorted - m = counts.max().item() if counts.numel() > 0 else 0 - n = pair_sizes.shape[0] - intra_pair_indices = ( - torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) - ) - mask = ( - torch.arange(intra_pair_indices.shape[2], device=ai1.device) - < pair_sizes.unsqueeze(1) - ).flatten() - sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] - sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) - # unsort result from last part - local_index12 = rev_indices[sorted_local_index12] +class Dense(nn.Linear): + """ + Fully connected linear layer with activation function. - # compute mapping between representation of central-other to pair - n = atom_pairs.shape[1] - sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1 - return central_atom_index, local_index12 % n, sign12 + forward(input) + Forward pass of the layer. + """ -class Embedding(nn.Module): - def __init__(self, num_embeddings: int, embedding_dim: int): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + activation_function: nn.Module = nn.Identity(), + ): """ - Initialize the embedding module. + A linear or non-linear transformation Parameters ---------- - num_embeddings: int - embedding_dim : int - Dimensionality of the embedding. + in_features : int + Number of input features. + out_features : int + Number of output features. + bias : bool, optional + If set to False, the layer will not learn an additive bias. Default is True. + activation_function : Type[torch.nn.Module] , optional + Activation function to be applied. Default is nn.Identity(), which applies the identity function + and makes this a linear transformation. """ - super().__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim) - @property - def data(self): - return self.embedding.weight.data + super().__init__(in_features, out_features, bias) - @data.setter - def data(self, data): - self.embedding.weight.data = data + self.activation_function = activation_function - @property - def embedding_dim(self): - """ - Get the dimensionality of the embedding. - - Returns - ------- - int - The dimensionality of the embedding. - """ - return self.embedding.embedding_dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, input: torch.Tensor): """ - Embeddes the pr3ovided 1D tensor using the embedding layer. + Forward pass of the layer. Parameters ---------- - x : torch.Tensor - 1D tensor to be embedded. + input : torch.Tensor + Input tensor. Returns ------- torch.Tensor - with shape (num_embeddings, embedding_dim) - """ - - return self.embedding(x) - + Output tensor after applying the linear transformation and activation function. -import torch.nn.functional as F -from torch.nn.init import xavier_uniform_, zeros_ + """ + y = F.linear(input, self.weight, self.bias) + return self.activation_function(y) -class Dense(nn.Linear): +class DenseWithCustomDist(nn.Linear): """ Fully connected linear layer with activation function. @@ -228,15 +155,12 @@ def __init__( in_features: int, out_features: int, bias: bool = True, - activation: Optional[ - Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]] - ] = None, + activation_function: nn.Module = nn.Identity(), weight_init: Callable = xavier_uniform_, bias_init: Callable = zeros_, ): """ - __init__ _summary_ - + A linear or non-linear transformation Parameters ---------- @@ -246,22 +170,23 @@ def __init__( Number of output features. bias : bool, optional If set to False, the layer will not learn an additive bias. Default is True. - activation : nn.Module or Callable[[torch.Tensor], torch.Tensor], optional - Activation function to be applied. Default is None, which applies the identity function and makes this a linear transformation. + activation_function : nn.Module , optional + Activation function to be applied. Default is nn.Identity(), which applies the identity function + and makes this a linear ransformation. weight_init : Callable, optional Callable to initialize the weights. Default is xavier_uniform_. bias_init : Callable, optional Function to initialize the bias. Default is zeros_. """ - # NOTE: these two variables need to come before the initi + # NOTE: these two variables need to come before the init self.weight_init_distribution = weight_init self.bias_init_distribution = bias_init super().__init__( in_features, out_features, bias - ) # NOTE: the `reseet_paramters` method is called in the super class + ) # NOTE: the `reset_paramters` method is called in the super class - self.activation = activation or nn.Identity() + self.activation_function = activation_function def reset_parameters(self): """ @@ -287,232 +212,43 @@ def forward(self, input: torch.Tensor): """ y = F.linear(input, self.weight, self.bias) - return self.activation(y) - - -from openff.units import unit - - -class CosineCutoff(nn.Module): - def __init__(self, cutoff: unit.Quantity): - """ - Behler-style cosine cutoff module. - NOTE: The cutoff is converted to nanometer and the input MUST be in nanomter too. - - Parameters: - ---------- - cutoff: unit.Quantity - The cutoff distance. - - """ - super().__init__() - cutoff = cutoff.to(unit.nanometer).m - self.register_buffer("cutoff", torch.tensor([cutoff])) - - def forward(self, d_ij: torch.Tensor): - """ - Compute the cosine cutoff for a distance tensor. - NOTE: the cutoff function doesn't care about units as long as they are consisten, - - Parameters - ---------- - d_ij : Tensor - Pairwise distance tensor in nanometer. Shape: [n_pairs, 1] - - Returns - ------- - Tensor - Cosine cutoff tensor. Shape: [n_pairs, 1] - """ - # Compute values of cutoff function - input_cut = 0.5 * ( - torch.cos(d_ij * np.pi / self.cutoff) + 1.0 - ) # NOTE: ANI adds 0.5 instead of 1. - # Remove contributions beyond the cutoff radius - input_cut *= (d_ij < self.cutoff).float() - return input_cut + return self.activation_function(y) from typing import Dict +from openff.units import unit + class ShiftedSoftplus(nn.Module): def __init__(self): super().__init__() - import math self.log_2 = math.log(2.0) def forward(self, x: torch.Tensor): - """Compute shifted soft-plus activation function. + """ + Compute shifted soft-plus activation function. - y = \ln\left(1 + e^{-x}\right) - \ln(2) + The shifted soft-plus activation function is defined as: + y = ln(1 + exp(-x)) - ln(2) Parameters: ----------- - x:torch.Tensor - input tensor + x : torch.Tensor + Input tensor. Returns: ----------- - torch.Tensor: shifted soft-plus of input. + torch.Tensor + Shifted soft-plus of the input. """ - from torch.nn import functional - - return functional.softplus(x) - self.log_2 - - -class AngularSymmetryFunction(nn.Module): - """ - Initialize AngularSymmetryFunction module. - - """ - - def __init__( - self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, - number_of_gaussians_for_asf: int = 8, - angle_sections: int = 4, - trainable: bool = False, - dtype: Optional[torch.dtype] = None, - ) -> None: - """ - Parameters - ---- - number_of_gaussian: Number of gaussian functions to use for angular symmetry function. - angular_cutoff: Cutoff distance for angular symmetry function. - angular_start: Starting distance for angular symmetry function. - ani_style: Whether to use ANI symmetry function style. - """ - - super().__init__() - from loguru import logger as log - - self.number_of_gaussians_asf = number_of_gaussians_for_asf - self.angular_cutoff = max_distance - self.cosine_cutoff = CosineCutoff(self.angular_cutoff) - _unitless_angular_cutoff = max_distance.to(unit.nanometer).m - self.angular_start = min_distance - _unitless_angular_start = min_distance.to(unit.nanometer).m - - # save constants - EtaA = angular_eta = 12.5 * 100 # FIXME hardcoded eta - Zeta = 14.1000 # FIXME hardcoded zeta - - if trainable: - self.EtaA = torch.tensor([EtaA], dtype=dtype) - self.Zeta = torch.tensor([Zeta], dtype=dtype) - self.Rca = torch.tensor([_unitless_angular_cutoff], dtype=dtype) - - else: - self.register_buffer("EtaA", torch.tensor([EtaA], dtype=dtype)) - self.register_buffer("Zeta", torch.tensor([Zeta], dtype=dtype)) - self.register_buffer( - "Rca", torch.tensor([_unitless_angular_cutoff], dtype=dtype) - ) - - # =============== - # # calculate shifts - # =============== - import math - - # ShfZ - angle_start = math.pi / (2 * angle_sections) - ShfZ = (torch.linspace(0, math.pi, angle_sections + 1) + angle_start)[:-1] - # ShfA - ShfA = torch.linspace( - _unitless_angular_start, - _unitless_angular_cutoff, - number_of_gaussians_for_asf + 1, - )[:-1] - # register shifts - if trainable: - self.ShfZ = ShfZ - self.ShfA = ShfA - else: - self.register_buffer("ShfZ", ShfZ) - self.register_buffer("ShfA", ShfA) - - # The length of angular subaev of a single species - self.angular_sublength = self.ShfA.numel() * self.ShfZ.numel() - - def forward(self, r_ij: torch.Tensor) -> torch.Tensor: - # calculate the angular sub aev - sub_aev = self.compute_angular_sub_aev(r_ij) - return sub_aev - - def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: - """Compute the angular subAEV terms of the center atom given neighbor pairs. - - This correspond to equation (4) in the ANI paper. This function just - compute the terms. The sum in the equation is not computed. - The input tensor have shape (conformations, atoms, N), where N - is the number of neighbor atom pairs within the cutoff radius and - output tensor should have shape - (conformations, atoms, ``self.angular_sublength()``) - - """ - vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - distances12 = vectors12.norm(2, dim=-5) - - # 0.95 is multiplied to the cos values to prevent acos from - # returning NaN. - cos_angles = 0.95 * torch.nn.functional.cosine_similarity( - vectors12[0], vectors12[1], dim=-5 - ) - angles = torch.acos(cos_angles) - fcj12 = self.cosine_cutoff(distances12) - factor1 = ((1 + torch.cos(angles - self.ShfZ)) / 2) ** self.Zeta - factor2 = torch.exp( - -self.EtaA * (distances12.sum(0) / 2 - self.ShfA) ** 2 - ).unsqueeze(-1) - factor2 = factor2.squeeze(4).squeeze(3) - ret = 2 * factor1 * factor2 * fcj12.prod(0) - # At this point, ret now have shape - # (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants. - # We then should flat the last 4 dimensions to view the subAEV as one - # dimension vector - return ret.flatten(start_dim=-4) - - -from abc import ABC, abstractmethod - - -class RadialBasisFunctionCore(nn.Module, ABC): - - def __init__(self, number_of_radial_basis_functions): - super().__init__() - self.number_of_radial_basis_functions = number_of_radial_basis_functions - - @abstractmethod - def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: - """ - Parameters - --------- - nondimensionalized_distances: torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] - Nondimensional quantities that depend on pairwise distances. - - Returns - --------- - torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] - """ - pass - - -class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): - - def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: - assert nondimensionalized_distances.ndim == 2 - assert ( - nondimensionalized_distances.shape[1] - == self.number_of_radial_basis_functions - ) - return torch.exp(-(nondimensionalized_distances**2)) + return F.softplus(x) - self.log_2 + class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore): """ Taken from SpookyNet. @@ -575,421 +311,7 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: ) return torch.exp(x) - - -class RadialBasisFunction(nn.Module, ABC): - - def __init__( - self, - radial_basis_function: RadialBasisFunctionCore, - dtype: torch.dtype, - prefactor: float = 1.0, - trainable_prefactor: bool = False, - ): - super().__init__() - if trainable_prefactor: - self.prefactor = nn.Parameter(torch.tensor([prefactor], dtype=dtype)) - else: - self.register_buffer("prefactor", torch.tensor([prefactor], dtype=dtype)) - self.radial_basis_function = radial_basis_function - - @abstractmethod - def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: - """ - Parameters - --------- - distances: torch.Tensor, shape [number_of_pairs, 1] - Distances between atoms in each pair in nanometers. - - Returns - --------- - torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] - Nondimensional quantities computed from the distances. - """ - pass - - def forward(self, distances: torch.Tensor) -> torch.Tensor: - """ - The input distances have implicit units of nanometers by the convention of modelforge. This function applies - nondimensionalization transformations on the distances and passes the dimensionless result to - RadialBasisFunctionCore. There can be several nondimsionalization transformations, corresponding to each element - along the number_of_radial_basis_functions axis in the output. - - Parameters - --------- - distances: torch.Tensor, shape [number_of_pairs, 1] - Distances between atoms in each pair in nanometers. - - Returns - --------- - torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] - Output of radial basis functions. - """ - nondimensionalized_distances = self.nondimensionalize_distances(distances) - return self.prefactor * self.radial_basis_function(nondimensionalized_distances) - - -class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): - """ - Shifts inputs by a set of centers and scales by a set of scale factors before passing into the standard Gaussian. - """ - - def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - prefactor: float = 1.0, - trainable_prefactor: bool = False, - trainable_centers_and_scale_factors: bool = False, - ): - """ - Parameters - --------- - number_of_radial_basis_functions: int - Number of radial basis functions to use. - max_distance: unit.Quantity - Maximum distance to consider for symmetry functions. - min_distance: unit.Quantity - Minimum distance to consider. - dtype: torch.dtype, default None - Data type for computations. - prefactor: float - Scalar factor by which to multiply output of radial basis functions. - trainable_prefactor: bool, default False - Whether prefactor is trainable - trainable_centers_and_scale_factors: bool, default False - Whether centers and scale factors are trainable. - """ - - super().__init__( - GaussianRadialBasisFunctionCore(number_of_radial_basis_functions), - dtype, - prefactor, - trainable_prefactor, - ) - self.number_of_radial_basis_functions = number_of_radial_basis_functions - self.dtype = dtype - self.trainable_centers_and_scale_factors = trainable_centers_and_scale_factors - # convert to nanometer - _max_distance_in_nanometer = max_distance.to(unit.nanometer).m - _min_distance_in_nanometer = min_distance.to(unit.nanometer).m - - # calculate radial basis centers - radial_basis_centers = self.calculate_radial_basis_centers( - self.number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - self.dtype, - ) - # calculate scale factors - radial_scale_factor = self.calculate_radial_scale_factor( - self.number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - self.dtype, - ) - - # either add as parameters or register buffers - if self.trainable_centers_and_scale_factors: - self.radial_basis_centers = radial_basis_centers - self.radial_scale_factor = radial_scale_factor - else: - self.register_buffer("radial_basis_centers", radial_basis_centers) - self.register_buffer("radial_scale_factor", radial_scale_factor) - - @staticmethod - @abstractmethod - def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, - ): - """ - NOTE: centers have units of nanometers - """ - pass - - @staticmethod - @abstractmethod - def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, - ): - """ - NOTE: radial scale factors have units of nanometers - """ - pass - - def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: - # Here, self.radial_scale_factor is interpreted as sqrt(2) times the standard deviation of the Gaussian. - diff = distances - self.radial_basis_centers - return diff / self.radial_scale_factor - - -class SchnetRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): - """ - Implementation of the radial basis function as used by the SchNet neural network - """ - - def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable_centers_and_scale_factors: bool = False, - ): - """ - Parameters - --------- - number_of_radial_basis_functions: int - Number of radial basis functions to use. - max_distance: unit.Quantity - Maximum distance to consider for symmetry functions. - min_distance: unit.Quantity - Minimum distance to consider. - dtype: torch.dtype, default None - Data type for computations. - trainable_centers_and_scale_factors: bool, default False - Whether centers and scale factors are trainable. - """ - super().__init__( - number_of_radial_basis_functions, - max_distance, - min_distance, - dtype, - trainable_prefactor=False, - trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, - ) - - @staticmethod - def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, - ): - return torch.linspace( - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype=dtype, - ) - - @staticmethod - def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, - ): - scale_factors = torch.linspace( - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype=dtype, - ) - - widths = torch.abs(scale_factors[1] - scale_factors[0]) * torch.ones_like( - scale_factors - ) - - scale_factors = math.sqrt(2) * widths - return scale_factors - - -class AniRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): - """ - Implementation of the radial basis function as used by the ANI neural network - """ - - def __init__( - self, - number_of_radial_basis_functions, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, - ): - """ - Parameters - --------- - number_of_radial_basis_functions: int - Number of radial basis functions to use. - max_distance: unit.Quantity - Maximum distance to consider for symmetry functions. - min_distance: unit.Quantity - Minimum distance to consider. - dtype: torch.dtype, default torch.float32 - Data type for computations. - trainable_centers_and_scale_factors: bool, default False - Whether centers and scale factors are trainable. - """ - super().__init__( - number_of_radial_basis_functions, - max_distance, - min_distance, - dtype, - prefactor=0.25, - trainable_prefactor=False, - trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, - ) - - @staticmethod - def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, - ): - centers = torch.linspace( - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions + 1, - dtype=dtype, - )[:-1] - return centers - - @staticmethod - def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, - ): - # ANI uses a predefined scaling factor - scale_factors = torch.full( - (number_of_radial_basis_functions,), (19.7 * 100) ** -0.5 - ) - return scale_factors - - -class PhysNetRadialBasisFunction(RadialBasisFunction): - """ - Implementation of the radial basis function as used by the PysNet neural network - """ - - def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - alpha: unit.Quantity = 1.0 * unit.angstrom, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, - ): - """ - Parameters - ---------- - number_of_radial_basis_functions : int - Number of radial basis functions to use. - max_distance : unit.Quantity - Maximum distance to consider for symmetry functions. - min_distance : unit.Quantity - Minimum distance to consider, by default 0.0 * unit.nanometer. - alpha: unit.Quantity - Scale factor used to nondimensionalize the input to all exp calls. The PhysNet paper implicitly divides by 1 - Angstrom within exponentials. Note that this is distinct from the unitless scale factors used outside the - exp but within the Gaussian. - dtype : torch.dtype, optional - Data type for computations, by default torch.float32. - trainable_centers_and_scale_factors : bool, optional - Whether centers and scale factors are trainable, by default False. - """ - - super().__init__( - GaussianRadialBasisFunctionCore(number_of_radial_basis_functions), - trainable_prefactor=False, - dtype=dtype, - ) - self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m - self._alpha_in_nanometer = alpha.to(unit.nanometer).m - radial_basis_centers = self.calculate_radial_basis_centers( - number_of_radial_basis_functions, - max_distance, - min_distance, - alpha, - dtype, - ) - # calculate scale factors - radial_scale_factor = self.calculate_radial_scale_factor( - number_of_radial_basis_functions, - max_distance, - min_distance, - alpha, - dtype, - ) - - if trainable_centers_and_scale_factors: - self.radial_basis_centers = radial_basis_centers - self.radial_scale_factor = radial_scale_factor - else: - self.register_buffer("radial_basis_centers", radial_basis_centers) - self.register_buffer("radial_scale_factor", radial_scale_factor) - - @staticmethod - def calculate_radial_basis_centers( - number_of_radial_basis_functions, - max_distance, - min_distance, - alpha, - dtype, - ): - # initialize centers according to the default values in PhysNet - # (see mu_k in Figure 2 caption of https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181) - # NOTE: Unlike GaussianRadialBasisFunctionWithScaling, the centers are unitless. - - start_value = torch.exp( - torch.scalar_tensor( - ((-max_distance + min_distance) / alpha).to("").m, - dtype=dtype, - ) - ) - centers = torch.linspace( - start_value, 1, number_of_radial_basis_functions, dtype=dtype - ) - return centers - - @staticmethod - def calculate_radial_scale_factor( - number_of_radial_basis_functions, - max_distance, - min_distance, - alpha, - dtype, - ): - # initialize according to the default values in PhysNet (see beta_k in Figure 2 caption) - # NOTES: - # - Unlike GaussianRadialBasisFunctionWithScaling, the scale factors are unitless. - # - Each element of radial_square_factor here is the reciprocal of the square root of beta_k in the - # Eq. 7 of the PhysNet paper. This way, it is consistent with the sqrt(2) * standard deviation interpretation - # of radial_scale_factor in GaussianRadialBasisFunctionWithScaling - return torch.full( - (number_of_radial_basis_functions,), - (2 * (1 - math.exp(((-max_distance + min_distance) / alpha).to("").m))) - / number_of_radial_basis_functions, - dtype=dtype, - ) - - def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: - # Transformation within the outer exp of PhysNet Eq. 7 - # NOTE: the PhysNet paper implicitly multiplies by 1/Angstrom within the inner exp but distances are in - # nanometers, so we multiply by 10/nanometer - - return ( - torch.exp( - (-distances + self._min_distance_in_nanometer) - / self._alpha_in_nanometer - ) - - self.radial_basis_centers - ) / self.radial_scale_factor - - + class ExponentialBernsteinRadialBasisFunction(RadialBasisFunction): def __init__(self, @@ -1073,36 +395,89 @@ def pair_list( return pair_indices.to(device) +from openff.units import unit + + +def convert_str_to_unit_in_dataset_statistics( + dataset_statistic: Dict[str, Dict[str, str]] +) -> Dict[str, Dict[str, unit.Quantity]]: + for key, value in dataset_statistic.items(): + for sub_key, sub_value in value.items(): + dataset_statistic[key][sub_key] = unit.Quantity(sub_value) + return dataset_statistic + + +def remove_units_from_dataset_statistics( + dataset_statistic: Dict[str, Dict[str, unit.Quantity]] +) -> Dict[str, Dict[str, float]]: + from openff.units import unit + + from modelforge.utils.units import chem_context + + dataset_statistic_without_units = {} + for key, value in dataset_statistic.items(): + dataset_statistic_without_units[key] = {} + for sub_key, sub_value in value.items(): + dataset_statistic_without_units[key][sub_key] = ( + unit.Quantity(sub_value).to(unit.kilojoule_per_mole, "chem").m + ) + return dataset_statistic_without_units + + +def read_dataset_statistics( + dataset_statistic_filename: str, remove_units: bool = False +): + import toml + + # read file + dataset_statistic = toml.load(dataset_statistic_filename) + # convert to float (to kJ/mol and then strip the units) + # dataset statistic is a Dict[str, Dict[str, unit.Quantity]], we need to strip the units + if remove_units: + return remove_units_from_dataset_statistics(dataset_statistic=dataset_statistic) + else: + return dataset_statistic + + def scatter_softmax( src: torch.Tensor, index: torch.Tensor, dim: int, - dim_size: Optional[int] = None, - device: Optional[torch.device] = None, + dim_size: int, ) -> torch.Tensor: """ - Softmax operation over all values in :attr:`src` tensor that share indices - specified in the :attr:`index` tensor along a given axis :attr:`dim`. + Computes the softmax operation over values in the `src` tensor that share indices specified in the `index` tensor + along a given axis `dim`. - For one-dimensional tensors, the operation computes + For one-dimensional tensors, the operation computes: .. math:: - \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i = - \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)} - - where :math:`\sum_j` is over :math:`j` such that - :math:`\mathrm{index}_j = i`. + \text{out}_i = \text{softmax}(\text{src})_i = + \frac{\exp(\text{src}_i)}{\sum_j \exp(\text{src}_j)} - Args: - src (Tensor): The source tensor. - index (LongTensor): The indices of elements to scatter. - dim (int, optional): The axis along which to index. - (default: :obj:`-1`) - dim_size: The number of classes, i.e. the number of unique indices in `index`. + where the summation :math:`\sum_j` is over all :math:`j` such that :math:`\text{index}_j = i`. - :rtype: :class:`Tensor` + Parameters + ---------- + src : Tensor + The source tensor containing the values to which the softmax operation will be applied. + index : LongTensor + The indices of elements to scatter, determining which elements in `src` are grouped together for the + softmax calculation. + dim : int + The axis along which to index. Default is `-1`. + dim_size : int + The number of classes, i.e., the number of unique indices in `index`. + + Returns + ------- + Tensor + A tensor where the softmax operation has been applied along the specified dimension. - Adapted from: https://github.com/rusty1s/pytorch_scatter/blob/c31915e1c4ceb27b2e7248d21576f685dc45dd01/torch_scatter/composite/softmax.py + Notes + ----- + This implementation is adapted from the following source: + `pytorch_scatter `_. """ if not torch.is_floating_point(src): raise ValueError( @@ -1120,7 +495,7 @@ def scatter_softmax( for (other_dim, other_dim_size) in enumerate(src.shape) ] index = index.to(torch.int64) - zeros = torch.zeros(out_shape, dtype=src.dtype, device=device) + zeros = torch.zeros(out_shape, dtype=src.dtype, device=src.device) max_value_per_index = zeros.scatter_reduce( dim, index, src, "amax", include_self=False ) @@ -1129,9 +504,24 @@ def scatter_softmax( recentered_scores = src - max_per_src_element recentered_scores_exp = recentered_scores.exp() - sum_per_index = torch.zeros(out_shape, dtype=src.dtype, device=device).scatter_add( - dim, index, recentered_scores_exp - ) + sum_per_index = torch.zeros( + out_shape, dtype=src.dtype, device=src.device + ).scatter_add(dim, index, recentered_scores_exp) normalizing_constants = sum_per_index.gather(dim, index) return recentered_scores_exp.div(normalizing_constants) + + +ACTIVATION_FUNCTIONS = { + "ReLU": nn.ReLU, + "CeLU": nn.CELU, + "GeLU": nn.GELU, + "Sigmoid": nn.Sigmoid, + "Softmax": nn.Softmax, + "ShiftedSoftplus": ShiftedSoftplus, + "SiLU": nn.SiLU, + "Tanh": nn.Tanh, + "LeakyReLU": nn.LeakyReLU, + "ELU": nn.ELU, + # Add more activation functions as needed +} diff --git a/modelforge/tests/__init__.py b/modelforge/tests/__init__.py index 2f065e73..1971314e 100644 --- a/modelforge/tests/__init__.py +++ b/modelforge/tests/__init__.py @@ -1,4 +1,4 @@ """ -Empty init file in case you choose a package besides PyTest such as Nose which may look -for such a file. +Empty init file in case you choose a package besides PyTest such as Nose which +may look for such a file. """ diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index 7e8e475f..c0572ad4 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -1,9 +1,10 @@ -import torch +from dataclasses import dataclass +from typing import Dict, Optional + import pytest -from modelforge.dataset import DataModule +import torch -from typing import Optional, Dict -from dataclasses import dataclass +from modelforge.dataset import DataModule # let us setup a few pytest options @@ -37,39 +38,6 @@ def create_datamodule(**kwargs): return create_datamodule -from modelforge.dataset.utils import ( - FirstComeFirstServeSplittingStrategy, - SplittingStrategy, -) - - -def initialize_datamodule( - dataset_name: str, - version_select: str = "nc_1000_v0", - batch_size: int = 64, - splitting_strategy: SplittingStrategy = FirstComeFirstServeSplittingStrategy(), - remove_self_energies: bool = True, - regression_ase: bool = False, - regenerate_dataset_statistic: bool = False, -) -> DataModule: - """ - Initialize a dataset for a given mode. - """ - - data_module = DataModule( - dataset_name, - splitting_strategy=splitting_strategy, - batch_size=batch_size, - version_select=version_select, - remove_self_energies=remove_self_energies, - regression_ase=regression_ase, - regenerate_dataset_statistic=regenerate_dataset_statistic, - ) - data_module.prepare_data() - data_module.setup() - return data_module - - # dataset fixture @pytest.fixture def dataset_factory(): @@ -79,83 +47,37 @@ def create_dataset(**kwargs): return create_dataset -from modelforge.dataset.dataset import DatasetFactory, TorchDataset -from modelforge.dataset import _ImplementedDatasets - - -def single_batch(batch_size: int = 64, dataset_name="QM9"): - """ - Utility function to create a single batch of data for testing. - """ - data_module = initialize_datamodule( - dataset_name=dataset_name, - batch_size=batch_size, - version_select="nc_1000_v0", - ) - return next(iter(data_module.train_dataloader(shuffle=False))) - - -@pytest.fixture(scope="session") -def single_batch_with_batchsize_64(): - """ - Utility fixture to create a single batch of data for testing. - """ - return single_batch(batch_size=64) - - -@pytest.fixture(scope="session") -def single_batch_with_batchsize_1(): - """ - Utility fixture to create a single batch of data for testing. - """ - return single_batch(batch_size=1) - - -@pytest.fixture(scope="session") -def single_batch_with_batchsize_2_with_force(): - """ - Utility fixture to create a single batch of data for testing. - """ - return single_batch(batch_size=2, dataset_name="PHALKETHOH") +from modelforge.dataset.dataset import ( + initialize_datamodule, + initialize_dataset, + single_batch, +) @pytest.fixture(scope="session") -def single_batch_with_batchsize_16_with_force(): +def single_batch_with_batchsize(): """ Utility fixture to create a single batch of data for testing. """ - return single_batch(batch_size=16, dataset_name="PHALKETHOH") - - -def initialize_dataset( - dataset_name: str, - local_cache_dir: str, - versions_select: str = "nc_1000_v0", - force_download: bool = False, -) -> DataModule: - """ - Initialize a dataset for a given mode. - """ - - factory = DatasetFactory() - data = _ImplementedDatasets.get_dataset_class(dataset_name)( - local_cache_dir=local_cache_dir, - version_select=versions_select, - force_download=force_download, - ) - dataset = factory.create_dataset(data) - - return dataset + def _create_single_batch(batch_size: int, dataset_name: str, local_cache_dir: str): + return single_batch( + batch_size=batch_size, + dataset_name=dataset_name, + local_cache_dir=local_cache_dir, + ) -@pytest.fixture(scope="session") -def prep_temp_dir(tmp_path_factory): - import uuid + return _create_single_batch - filename = str(uuid.uuid4()) - tmp_path_factory.mktemp(f"dataset_test/") - return f"dataset_test" +# @pytest.fixture(scope="session") +# def prep_temp_dir(tmp_path_factory): +# import uuid +# +# filename = str(uuid.uuid4()) +# +# tmp_path_factory.mktemp(f"dataset_test/") +# return f"dataset_test" @dataclass @@ -170,7 +92,6 @@ class DataSetContainer: from modelforge.dataset import _ImplementedDatasets - dataset_container: Dict[str, DataSetContainer] = { "QM9": DataSetContainer( name="QM9", @@ -276,7 +197,7 @@ def equivariance_utils(): # helper functions # ----------------------------------------------------------- # -from modelforge.dataset.dataset import BatchData +from modelforge.utils.prop import BatchData, NNPInput @pytest.fixture @@ -288,7 +209,7 @@ def methane() -> BatchData: ------- BatchData """ - from modelforge.potential.utils import Metadata, NNPInput, BatchData + from modelforge.potential.utils import BatchData, Metadata atomic_numbers = torch.tensor([6, 1, 1, 1, 1], dtype=torch.int64) positions = ( @@ -311,7 +232,7 @@ def methane() -> BatchData: atomic_numbers=atomic_numbers, positions=positions, atomic_subsystem_indices=atomic_subsystem_indices, - total_charge=torch.tensor([0.0]), + per_system_total_charge=torch.tensor([0.0]), ), Metadata( E=E, @@ -322,9 +243,10 @@ def methane() -> BatchData: ) -import torch import math +import torch + def generate_uniform_quaternion(u=None): """ @@ -366,7 +288,7 @@ def generate_uniform_quaternion(u=None): def rotation_matrix_from_quaternion(quaternion): """Compute a 3x3 rotation matrix from a given quaternion (4-vector). - Adapted from the numpy implementation in openmm-tools + Adapted from the numpy implementation in modelforgeopenmm-tools https://github.com/choderalab/openmmtools/blob/main/openmmtools/mcmc.py @@ -438,11 +360,12 @@ def apply_rotation_matrix(coordinates, rotation_matrix, use_center_of_mass=True) if use_center_of_mass: coordinates_com = torch.mean(coordinates, 0) else: - coordinates_com = torch.zeros(3) + coordinates_com = torch.zeros(3).to(coordinates.device, coordinates.dtype) coordinates_proposed = ( torch.matmul( - rotation_matrix, (coordinates - coordinates_com).transpose(0, -1) + rotation_matrix.to(coordinates.device, coordinates.dtype), + (coordinates - coordinates_com).transpose(0, -1), ).transpose(0, -1) ) + coordinates_com @@ -474,13 +397,13 @@ def equivariance_test_utils(): x_translation = torch.randn( size=(1, 3), ) - translation = lambda x: x + x_translation + translation = lambda x: x + x_translation.to(x.device, x.dtype) # generate random quaternion and rotation matrix q = generate_uniform_quaternion() rotation_matrix = rotation_matrix_from_quaternion(q) - rotation = lambda x: apply_rotation_matrix(x, rotation_matrix) + rotation = lambda x: apply_rotation_matrix(x, rotation_matrix.to(x.device, x.dtype)) # Define reflection function alpha = torch.distributions.Uniform(-math.pi, math.pi).sample() @@ -491,6 +414,6 @@ def equivariance_test_utils(): p = torch.eye(3) - 2 * v.T @ v - reflection = lambda x: x @ p + reflection = lambda x: x @ p.to(x.device, x.dtype) return translation, rotation, reflection diff --git a/modelforge/tests/data/best_SchNet-PhAlkEthOH-epoch=00.ckpt b/modelforge/tests/data/best_SchNet-PhAlkEthOH-epoch=00.ckpt new file mode 100644 index 00000000..cbab2382 Binary files /dev/null and b/modelforge/tests/data/best_SchNet-PhAlkEthOH-epoch=00.ckpt differ diff --git a/modelforge/tests/data/config.toml b/modelforge/tests/data/config.toml new file mode 100644 index 00000000..6dbec30b --- /dev/null +++ b/modelforge/tests/data/config.toml @@ -0,0 +1,91 @@ +[potential] +potential_name = "SchNet" + +[potential.core_parameter] +number_of_radial_basis_functions = 20 +maximum_interaction_radius = "5.0 angstrom" +number_of_interaction_modules = 3 +number_of_filters = 32 +shared_interactions = false +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "ShiftedSoftplus" + +[potential.core_parameter.featurization] +properties_to_featurize = ['atomic_number'] +[potential.core_parameter.featurization.atomic_number] +maximum_atomic_number = 101 +number_of_per_atom_features = 32 + +[potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] +[potential.postprocessing_parameter.per_atom_energy] +normalize = true +from_atom_to_system_reduction = true +keep_per_atom_property = true +[potential.postprocessing_parameter.general_postprocessing_operation] +calculate_molecular_self_energy = true + +[dataset] +dataset_name = "QM9" +version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true + +[training] +number_of_epochs = 2 +remove_self_energies = true +batch_size = 128 +lr = 1e-3 +monitor = "val/per_system_energy/rmse" +shift_center_of_mass_to_origin = false + +[training.experiment_logger] +logger_name = "tensorboard" + +[training.experiment_logger.tensorboard_configuration] +save_dir = "logs" + +[training.lr_scheduler] +scheduler_name = "ReduceLROnPlateau" +frequency = 1 +mode = "min" +factor = 0.1 +patience = 10 +cooldown = 5 +min_lr = 1e-8 +threshold = 0.1 +threshold_mode = "abs" +interval = "epoch" + +[training.loss_parameter] +loss_components = ['per_system_energy', 'per_atom_force'] # use + +[training.loss_parameter.weight] +per_system_energy = 0.999 #NOTE: reciprocal units +per_atom_force = 0.001 + + +[training.early_stopping] +verbose = true +min_delta = 0.001 +patience = 50 + +[training.splitting_strategy] +name = "random_record_splitting_strategy" +data_split = [0.8, 0.1, 0.1] +seed = 42 + +[runtime] +verbose = true +save_dir = "lightning_logs" +experiment_name = "{potential_name}_{dataset_name}" +local_cache_dir = "./cache" +accelerator = "cpu" +number_of_nodes = 1 +devices = 1 #[0,1,2,3] +checkpoint_path = "None" +simulation_environment = "PyTorch" +log_every_n_steps = 1 diff --git a/modelforge/tests/data/conv.ipynb b/modelforge/tests/data/conv.ipynb new file mode 100644 index 00000000..e8494257 --- /dev/null +++ b/modelforge/tests/data/conv.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from modelforge.dataset.dataset import NNPInput\n", + "import pickle\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_283011/2353297571.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + " p = torch.load('positions.pt')\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([[ 0.0109, 0.1550, -0.0058],\n", + " [-0.0034, 0.0045, 0.0078],\n", + " [ 0.0782, -0.0711, -0.0857],\n", + " [ 0.1484, -0.1434, 0.0173],\n", + " [ 0.0706, -0.0646, 0.1236],\n", + " [ 0.0694, -0.0576, 0.2425],\n", + " [ 0.1155, 0.1854, 0.0032],\n", + " [-0.0266, 0.1878, -0.1031],\n", + " [-0.0471, 0.2049, 0.0725],\n", + " [-0.1089, -0.0258, 0.0012],\n", + " [ 0.1281, -0.2513, 0.0159],\n", + " [ 0.2569, -0.1269, 0.0154]], requires_grad=True)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# save positions to a file\n", + "p = torch.load('positions.pt')\n", + "p" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "atomic_numbers=torch.tensor([6, 6, 8, 6, 6, 8, 1, 1, 1, 1, 1, 1], dtype=torch.int32)\n", + "position = p\n", + "atomic_subsystem_indices=torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int32)\n", + "\n", + "total_charge=torch.tensor([0], dtype=torch.int32)\n", + "pair_list=torch.tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,\n", + " 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,\n", + " 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8,\n", + " 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9,\n", + " 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11,\n", + " 11, 11, 11, 11, 11, 11],\n", + " [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 2, 3, 4, 5, 6, 7,\n", + " 8, 9, 10, 11, 0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2,\n", + " 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 5, 6, 7, 8, 9, 10,\n", + " 11, 0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5,\n", + " 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 0, 1,\n", + " 2, 3, 4, 5, 6, 7, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8,\n", + " 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 0, 1, 2, 3, 4,\n", + " 5, 6, 7, 8, 9, 10]])\n", + "partial_charge=None\n", + "box_vectors=torch.tensor([[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]])\n", + "is_periodic=torch.tensor([0.])\n", + "\n", + "\n", + "\n", + "\n", + "nnp_input = NNPInput(atomic_numbers=atomic_numbers, positions=position, atomic_subsystem_indices=atomic_subsystem_indices, total_charge=total_charge, pair_list=pair_list, partial_charge=partial_charge, box_vectors=box_vectors, is_periodic=is_periodic)\n", + "\n", + "# save as pickle\n", + "with open('nnp_input.pkl', 'wb') as f:\n", + " pickle.dump(nnp_input, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Example usage for loading and migrating\n", + "with open('mf_input.pkl', 'rb') as f:\n", + " old_instance = pickle.load(f)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "old_instance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import pickle\n", + "from modelforge.datasets import NNPInput\n", + "# Define a migration function to map old instance to new instance\n", + "def migrate_nnpinput(old_instance):\n", + " # Extract the attributes from the old class\n", + " atomic_numbers = old_instance.atomic_numbers\n", + " positions = old_instance.positions\n", + " atomic_subsystem_indices = old_instance.atomic_subsystem_indices\n", + " total_charge = old_instance.total_charge\n", + " pair_list = getattr(old_instance, \"pair_list\", None) # Optional attributes\n", + " partial_charge = getattr(old_instance, \"partial_charge\", None)\n", + " box_vectors = getattr(old_instance, \"box_vectors\", torch.zeros(3, 3)) # Default box_vectors\n", + " is_periodic = getattr(old_instance, \"is_periodic\", torch.tensor([False])) # Default is_periodic\n", + "\n", + " # Create an instance of the new NNPInput class with migrated attributes\n", + " new_instance = NNPInput(\n", + " atomic_numbers=atomic_numbers,\n", + " positions=positions,\n", + " atomic_subsystem_indices=atomic_subsystem_indices,\n", + " total_charge=total_charge,\n", + " box_vectors=box_vectors,\n", + " is_periodic=is_periodic,\n", + " pair_list=pair_list,\n", + " partial_charge=partial_charge,\n", + " )\n", + "\n", + " return new_instance\n", + "\n", + "# Example usage for loading and migrating\n", + "with open('old_nnpinput.pickle', 'rb') as f:\n", + " old_instance = pickle.load(f)\n", + "\n", + "# Migrate old instance to the new class\n", + "new_instance = migrate_nnpinput(old_instance)\n", + "\n", + "# Now you can work with the new_instance using the updated class definition\n", + "print(new_instance)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "modelforge", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/modelforge/tests/data/dataset_defaults/QM9_dataset_statistic.toml b/modelforge/tests/data/dataset_defaults/QM9_dataset_statistic.toml new file mode 100644 index 00000000..fa6d70ae --- /dev/null +++ b/modelforge/tests/data/dataset_defaults/QM9_dataset_statistic.toml @@ -0,0 +1,10 @@ +[atomic_self_energies] +H = "-1313.4668615546 kilojoule_per_mole" +C = "-99366.70745535441 kilojoule_per_mole" +N = "-143309.9379722722 kilojoule_per_mole" +O = "-197082.0671774158 kilojoule_per_mole" +F = "-261811.54555874597 kilojoule_per_mole" + +[training_dataset_statistics] +per_atom_energy_mean = "-402.9165610695209 kilojoule_per_mole" +per_atom_energy_stddev = "25.013382078330697 kilojoule_per_mole" diff --git a/modelforge/tests/data/dataset_defaults/ani1x.toml b/modelforge/tests/data/dataset_defaults/ani1x.toml index be7fc94a..c41e86d6 100644 --- a/modelforge/tests/data/dataset_defaults/ani1x.toml +++ b/modelforge/tests/data/dataset_defaults/ani1x.toml @@ -1,3 +1,5 @@ [dataset] dataset_name = "ANI1x" version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true diff --git a/modelforge/tests/data/dataset_defaults/ani2x.toml b/modelforge/tests/data/dataset_defaults/ani2x.toml index ce6232c0..fb1a2bf8 100644 --- a/modelforge/tests/data/dataset_defaults/ani2x.toml +++ b/modelforge/tests/data/dataset_defaults/ani2x.toml @@ -1,3 +1,5 @@ [dataset] dataset_name = "ANI2x" version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true \ No newline at end of file diff --git a/modelforge/tests/data/dataset_defaults/phalkethoh.toml b/modelforge/tests/data/dataset_defaults/phalkethoh.toml index 5e89f1cd..ddac0202 100644 --- a/modelforge/tests/data/dataset_defaults/phalkethoh.toml +++ b/modelforge/tests/data/dataset_defaults/phalkethoh.toml @@ -1,3 +1,5 @@ [dataset] dataset_name = "PHALKETHOH" -version_select = "nc_1000_v0" +version_select = "nc_1000_v1" +num_workers = 4 +pin_memory = true \ No newline at end of file diff --git a/modelforge/tests/data/dataset_defaults/qm9.toml b/modelforge/tests/data/dataset_defaults/qm9.toml index 43759ece..797845c7 100644 --- a/modelforge/tests/data/dataset_defaults/qm9.toml +++ b/modelforge/tests/data/dataset_defaults/qm9.toml @@ -1,3 +1,5 @@ [dataset] dataset_name = "QM9" -version_select = "nc_1000_v0" \ No newline at end of file +version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true \ No newline at end of file diff --git a/modelforge/tests/data/dataset_defaults/spice1.toml b/modelforge/tests/data/dataset_defaults/spice1.toml index a89eec95..decd0e2b 100644 --- a/modelforge/tests/data/dataset_defaults/spice1.toml +++ b/modelforge/tests/data/dataset_defaults/spice1.toml @@ -1,3 +1,5 @@ [dataset] dataset_name = "SPICE1" version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true \ No newline at end of file diff --git a/modelforge/tests/data/dataset_defaults/spice1_openff.toml b/modelforge/tests/data/dataset_defaults/spice1_openff.toml index f7957ee3..5f3d60e7 100644 --- a/modelforge/tests/data/dataset_defaults/spice1_openff.toml +++ b/modelforge/tests/data/dataset_defaults/spice1_openff.toml @@ -1,3 +1,5 @@ [dataset] dataset_name = "SPICE1_OPENFF" version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true \ No newline at end of file diff --git a/modelforge/tests/data/dataset_defaults/spice2.toml b/modelforge/tests/data/dataset_defaults/spice2.toml index ebf7c6e0..8f655b30 100644 --- a/modelforge/tests/data/dataset_defaults/spice2.toml +++ b/modelforge/tests/data/dataset_defaults/spice2.toml @@ -1,3 +1,5 @@ [dataset] dataset_name = "SPICE2" version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true \ No newline at end of file diff --git a/modelforge/tests/data/mf_input.pkl b/modelforge/tests/data/mf_input.pkl new file mode 100644 index 00000000..e2dd7077 Binary files /dev/null and b/modelforge/tests/data/mf_input.pkl differ diff --git a/modelforge/tests/data/nnp_input.pkl b/modelforge/tests/data/nnp_input.pkl new file mode 100644 index 00000000..eac3471b Binary files /dev/null and b/modelforge/tests/data/nnp_input.pkl differ diff --git a/modelforge/tests/data/positions.pt b/modelforge/tests/data/positions.pt new file mode 100644 index 00000000..1fe209a4 Binary files /dev/null and b/modelforge/tests/data/positions.pt differ diff --git a/modelforge/tests/data/potential_defaults/aimnet2.toml b/modelforge/tests/data/potential_defaults/aimnet2.toml new file mode 100644 index 00000000..223df414 --- /dev/null +++ b/modelforge/tests/data/potential_defaults/aimnet2.toml @@ -0,0 +1,26 @@ +[potential] +potential_name = "AimNet2" + +[potential.core_parameter] +number_of_radial_basis_functions = 64 +number_of_vector_features = 8 +maximum_interaction_radius = "5.0 angstrom" +number_of_interaction_modules = 3 +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "GeLU" + +[potential.core_parameter.featurization] +properties_to_featurize = ['atomic_number'] +[potential.core_parameter.featurization.atomic_number] +maximum_atomic_number = 101 +number_of_per_atom_features = 64 + +[potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] +[potential.postprocessing_parameter.per_atom_energy] +normalize = true +from_atom_to_system_reduction = true +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/ani2x.toml b/modelforge/tests/data/potential_defaults/ani2x.toml index 05cae5b9..8ecc4872 100644 --- a/modelforge/tests/data/potential_defaults/ani2x.toml +++ b/modelforge/tests/data/potential_defaults/ani2x.toml @@ -1,17 +1,26 @@ [potential] -model_name = "ANI2x" +potential_name = "ANI2x" [potential.core_parameter] angle_sections = 4 -radial_max_distance = "5.1 angstrom" -radial_min_distance = "0.8 angstrom" +maximum_interaction_radius = "5.1 angstrom" +minimum_interaction_radius = "0.8 angstrom" number_of_radial_basis_functions = 16 -angular_max_distance = "3.5 angstrom" -angular_min_distance = "0.8 angstrom" +maximum_interaction_radius_for_angular_features = "3.5 angstrom" +minimum_interaction_radius_for_angular_features = "0.8 angstrom" angular_dist_divisions = 8 +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "CeLU" + +[potential.core_parameter.activation_function_parameter.activation_function_arguments] +alpha = 0.1 [potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] [potential.postprocessing_parameter.per_atom_energy] normalize = true -from_atom_to_molecule_reduction = true +from_atom_to_system_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/painn.toml b/modelforge/tests/data/potential_defaults/painn.toml index 70292773..163e1f7a 100644 --- a/modelforge/tests/data/potential_defaults/painn.toml +++ b/modelforge/tests/data/potential_defaults/painn.toml @@ -1,18 +1,27 @@ [potential] -model_name = "PaiNN" +potential_name = "PaiNN" [potential.core_parameter] - -max_Z = 101 -number_of_atom_features = 32 -number_of_radial_basis_functions = 20 -cutoff = "5.0 angstrom" +number_of_radial_basis_functions = 16 +maximum_interaction_radius = "5.0 angstrom" number_of_interaction_modules = 3 shared_interactions = false shared_filters = false +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "SiLU" + +[potential.core_parameter.featurization] +properties_to_featurize = ['atomic_number'] +[potential.core_parameter.featurization.atomic_number] +maximum_atomic_number = 101 +number_of_per_atom_features = 32 [potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] [potential.postprocessing_parameter.per_atom_energy] normalize = true -from_atom_to_molecule_reduction = true +from_atom_to_system_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/physnet.toml b/modelforge/tests/data/potential_defaults/physnet.toml index 68b76d91..72a808b9 100644 --- a/modelforge/tests/data/potential_defaults/physnet.toml +++ b/modelforge/tests/data/potential_defaults/physnet.toml @@ -1,17 +1,27 @@ [potential] -model_name = "PhysNet" +potential_name = "PhysNet" [potential.core_parameter] -max_Z = 101 -number_of_atom_features = 64 number_of_radial_basis_functions = 16 -cutoff = "5.0 angstrom" +maximum_interaction_radius = "5.0 angstrom" number_of_interaction_residual = 3 number_of_modules = 5 +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "ShiftedSoftplus" + +[potential.core_parameter.featurization] +properties_to_featurize = ['atomic_number'] +[potential.core_parameter.featurization.atomic_number] +maximum_atomic_number = 101 +number_of_per_atom_features = 32 [potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] [potential.postprocessing_parameter.per_atom_energy] normalize = true -from_atom_to_molecule_reduction = true +from_atom_to_system_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/sake.toml b/modelforge/tests/data/potential_defaults/sake.toml index d8fb2cc5..84e5664e 100644 --- a/modelforge/tests/data/potential_defaults/sake.toml +++ b/modelforge/tests/data/potential_defaults/sake.toml @@ -1,17 +1,27 @@ [potential] -model_name = "SAKE" +potential_name = "SAKE" [potential.core_parameter] -max_Z = 101 -number_of_atom_features = 64 -number_of_radial_basis_functions = 50 -cutoff = "5.0 angstrom" +number_of_radial_basis_functions = 11 +maximum_interaction_radius = "5.0 angstrom" number_of_interaction_modules = 6 number_of_spatial_attention_heads = 4 +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "SiLU" + +[potential.core_parameter.featurization] +properties_to_featurize = ['atomic_number'] +[potential.core_parameter.featurization.atomic_number] +maximum_atomic_number = 101 +number_of_per_atom_features = 11 [potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] [potential.postprocessing_parameter.per_atom_energy] normalize = true -from_atom_to_molecule_reduction = true +from_atom_to_system_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index f5b0094d..70cdcde5 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -1,19 +1,40 @@ +# ------------------------------------------------------------ # [potential] -model_name = "SchNet" - +potential_name = "SchNet" +# ------------------------------------------------------------ # [potential.core_parameter] -max_Z = 101 -number_of_atom_features = 32 -number_of_radial_basis_functions = 20 -cutoff = "5.0 angstrom" +number_of_radial_basis_functions = 16 +maximum_interaction_radius = "5.0 angstrom" number_of_interaction_modules = 3 number_of_filters = 32 shared_interactions = false - +predicted_properties = ["per_atom_energy", 'per_atom_charge'] +predicted_dim = [1, 1] +# ------------------------------------------------------------ # +[potential.core_parameter.activation_function_parameter] +activation_function_name = "ShiftedSoftplus" +# ------------------------------------------------------------ # +[potential.core_parameter.featurization] +properties_to_featurize = ['atomic_number'] +[potential.core_parameter.featurization.atomic_number] +maximum_atomic_number = 101 +number_of_per_atom_features = 32 +# ------------------------------------------------------------ # [potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] [potential.postprocessing_parameter.per_atom_energy] normalize = true -from_atom_to_molecule_reduction = true +from_atom_to_system_reduction = true keep_per_atom_property = true -[potential.postprocessing_parameter.general_postprocessing_operation] -calculate_molecular_self_energy = true +[potential.postprocessing_parameter.per_atom_charge] +conserve = true +conserve_strategy = "default" +# ------------------------------------------------------------ # +# [potential.postprocessing_parameter.per_atom_charge.coulomb_potential] +# electrostatic_strategy = "coulomb" +# maximum_interaction_radius = "10.0 angstrom" +# from_atom_to_system_reduction = true +# keep_per_atom_property = true + +# [potential.postprocessing_parameter.general_postprocessing_operation] +# calculate_molecular_self_energy = true diff --git a/modelforge/tests/data/potential_defaults/tensornet.toml b/modelforge/tests/data/potential_defaults/tensornet.toml new file mode 100644 index 00000000..31fbc8ff --- /dev/null +++ b/modelforge/tests/data/potential_defaults/tensornet.toml @@ -0,0 +1,26 @@ +[potential] +potential_name = "TensorNet" + +[potential.core_parameter] +number_of_per_atom_features = 8 +number_of_interaction_layers = 2 +number_of_radial_basis_functions = 16 +maximum_interaction_radius = "5.1 angstrom" +minimum_interaction_radius = "0.0 angstrom" +maximum_atomic_number = 128 +equivariance_invariance_group = "O(3)" +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "SiLU" + +[potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] +[potential.postprocessing_parameter.per_atom_energy] +normalize = true +from_atom_to_system_reduction = true +keep_per_atom_property = true + +[potential.postprocessing_parameter.general_postprocessing_operation] +calculate_molecular_self_energy = true diff --git a/modelforge/tests/data/runtime_defaults/runtime.toml b/modelforge/tests/data/runtime_defaults/runtime.toml index 7520d1d5..6893eaef 100644 --- a/modelforge/tests/data/runtime_defaults/runtime.toml +++ b/modelforge/tests/data/runtime_defaults/runtime.toml @@ -1,7 +1,10 @@ [runtime] -experiment_name = "exp_test" -accelerator = "cpu" -num_nodes = 1 -devices = 1 #[0,1,2,3] +experiment_name = "{potential_name}_{dataset_name}" local_cache_dir = "./cache" -save_dir = "test" +accelerator = "cpu" +number_of_nodes = 1 +devices = 1 #[0,1,2,3] +checkpoint_path = "None" +simulation_environment = "PyTorch" +log_every_n_steps = 50 +verbose = true diff --git a/modelforge/tests/data/tensornet_input.pt b/modelforge/tests/data/tensornet_input.pt new file mode 100644 index 00000000..450edcf0 Binary files /dev/null and b/modelforge/tests/data/tensornet_input.pt differ diff --git a/modelforge/tests/data/tensornet_interaction.pt b/modelforge/tests/data/tensornet_interaction.pt new file mode 100644 index 00000000..8c8b4dcb Binary files /dev/null and b/modelforge/tests/data/tensornet_interaction.pt differ diff --git a/modelforge/tests/data/tensornet_radial_symmetry_features.pt b/modelforge/tests/data/tensornet_radial_symmetry_features.pt new file mode 100644 index 00000000..2227c945 Binary files /dev/null and b/modelforge/tests/data/tensornet_radial_symmetry_features.pt differ diff --git a/modelforge/tests/data/tensornet_representation.pt b/modelforge/tests/data/tensornet_representation.pt new file mode 100644 index 00000000..410e6095 Binary files /dev/null and b/modelforge/tests/data/tensornet_representation.pt differ diff --git a/modelforge/tests/data/torchani_parameters.state b/modelforge/tests/data/torchani_parameters.state new file mode 100644 index 00000000..6ed96faa Binary files /dev/null and b/modelforge/tests/data/torchani_parameters.state differ diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml index ec73b2a3..122dfebc 100644 --- a/modelforge/tests/data/training_defaults/default.toml +++ b/modelforge/tests/data/training_defaults/default.toml @@ -1,40 +1,54 @@ [training] -nr_of_epochs = 5 -num_nodes = 1 -devices = 1 # [0,1,2,3] +number_of_epochs = 2 remove_self_energies = true +shift_center_of_mass_to_origin = false batch_size = 128 - -[training.training_parameter] -lr = 1e-3 - +lr = 5e-4 +monitor = "val/per_system_energy/rmse" # Common monitor key +plot_frequency = 1 +# ------------------------------------------------------------ # [training.experiment_logger] -logger_name = "tensorboard" - -[training.training_parameter.lr_scheduler_config] +logger_name = "tensorboard" # this will set which logger to use +[training.experiment_logger.tensorboard_configuration] +save_dir = "logs" +# ------------------------------------------------------------ # +[training.experiment_logger.wandb_configuration] +save_dir = "logs" +project = "tests" +group = "exp00" +log_model = true +job_type = "testing" +tags = ["v_0.1.0"] +notes = "testing training" +# ------------------------------------------------------------ # +# Learning Rate Scheduler Configuration +[training.lr_scheduler] +scheduler_name = "ReduceLROnPlateau" frequency = 1 +interval = "epoch" +monitor = "val/per_system_energy/rmse" mode = "min" factor = 0.1 patience = 10 -cooldown = 5 -min_lr = 1e-8 threshold = 0.1 threshold_mode = "abs" -monitor = "val/per_molecule_energy/rmse" -interval = "epoch" - -[training.training_parameter.loss_parameter] -loss_property = ['per_molecule_energy', 'per_atom_force'] # use . -[training.training_parameter.loss_parameter.weight] -per_molecule_energy = 0.999 #NOTE: reciprocal units -per_atom_force = 0.001 - +cooldown = 5 +min_lr = 1e-8 +eps = 1e-8 # Optional, default is 1e-8 +# ------------------------------------------------------------ # +[training.loss_parameter] +loss_components = ['per_system_energy'] #, 'per_atom_force'] +# ------------------------------------------------------------ # +[training.loss_parameter.weight] +per_system_energy = 1.0 +# ------------------------------------------------------------ # [training.early_stopping] verbose = true -monitor = "val/per_molecule_energy/rmse" min_delta = 0.001 patience = 50 - +# ------------------------------------------------------------ # [training.splitting_strategy] name = "random_record_splitting_strategy" data_split = [0.8, 0.1, 0.1] +seed = 42 +# ------------------------------------------------------------ # diff --git a/modelforge/tests/helper_functions.py b/modelforge/tests/helper_functions.py index e69de29b..1f518b95 100644 --- a/modelforge/tests/helper_functions.py +++ b/modelforge/tests/helper_functions.py @@ -0,0 +1,84 @@ +from typing import Optional, Literal + + +def _add_per_atom_charge_to_predicted_properties(config): + config["potential"].core_parameter.predicted_properties.append("per_atom_charge") + config["potential"].core_parameter.predicted_dim.append(1) + return config + + +def _add_per_atom_charge_to_properties_to_process(config): + config["potential"].postprocessing_parameter.properties_to_process.append( + "per_atom_charge" + ) + from modelforge.potential.parameters import PerAtomCharge + + config["potential"].postprocessing_parameter.per_atom_charge = PerAtomCharge( + conserve=True, conserve_strategy="default" + ) + + return config + + +def _add_electrostatic_to_predicted_properties(config): + from modelforge.potential.parameters import ElectrostaticPotential + from openff.units import unit + + config["potential"].postprocessing_parameter.properties_to_process.append( + "electrostatic_potential" + ) + config["potential"].postprocessing_parameter.electrostatic_potential = ( + ElectrostaticPotential( + electrostatic_strategy="coulomb", + maximum_interaction_radius=10.0 * unit.angstrom, + ) + ) + + return config + + +def setup_potential_for_test( + potential_name: str, + use: str, + use_default_dataset_statistic: bool = True, + use_training_mode_neighborlist: bool = True, + jit: bool = False, + potential_seed: Optional[int] = None, + simulation_environment: Literal["PyTorch", "JAX"] = "PyTorch", + local_cache_dir: Optional[str] = None, +): + from modelforge.potential import NeuralNetworkPotentialFactory + from modelforge.tests.test_potentials import load_configs_into_pydantic_models + + if simulation_environment == "JAX": + assert use == "inference", "JAX only supports inference mode" + + # read default parameters + config = load_configs_into_pydantic_models(potential_name, "qm9") + # override defaults to match reference implementation in spk + + if local_cache_dir is not None: + config["runtime"].local_cache_dir = local_cache_dir + + if use == "training": + trainer = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config["potential"], + runtime_parameter=config["runtime"], + training_parameter=config["training"], + dataset_parameter=config["dataset"], + potential_seed=potential_seed, + use_default_dataset_statistic=use_default_dataset_statistic, + ) + potential = trainer.lightning_module.potential + else: + potential = NeuralNetworkPotentialFactory.generate_potential( + potential_parameter=config["potential"], + training_parameter=config["training"], + dataset_parameter=config["dataset"], + potential_seed=potential_seed, + simulation_environment=simulation_environment, + use_training_mode_neighborlist=use_training_mode_neighborlist, + jit=jit, + ) + + return potential diff --git a/modelforge/tests/precalculated_values.py b/modelforge/tests/precalculated_values.py index 48a4b9c1..24aa1620 100644 --- a/modelforge/tests/precalculated_values.py +++ b/modelforge/tests/precalculated_values.py @@ -106,13 +106,13 @@ def setup_single_methane_input(): ) E = torch.tensor([0.0], requires_grad=True) atomic_subsystem_indices = torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32) - from modelforge.dataset.dataset import NNPInput + from modelforge.utils.prop import NNPInput modelforge_methane = NNPInput( atomic_numbers=atomic_numbers, positions=positions, atomic_subsystem_indices=atomic_subsystem_indices, - total_charge=torch.tensor([0], dtype=torch.int32), + per_system_total_charge=torch.tensor([0], dtype=torch.int32), ) # ------------------------------------ # @@ -613,7 +613,7 @@ def calculate_reference(): def provide_reference_values_for_test_ani_test_compute_rsf_with_diagonal_batching(): def calculate_reference(): from torchani.aev import neighbor_pairs_nopbc - from modelforge.potential.models import Pairlist + from modelforge.potential.potential import Pairlist from modelforge.tests.test_ani import setup_two_methanes # ------------ general setup -------------# @@ -1070,7 +1070,7 @@ def calculate_values(): ShfZ = (torch.linspace(0, math.pi, angle_sections + 1) + angle_start)[:-1] # set up relevant system properties from modelforge.tests.test_ani import setup_methane - from modelforge.potential.models import Pairlist + from modelforge.potential.potential import Pairlist from modelforge.potential.utils import triple_by_molecule device = torch.device("cpu") @@ -2279,3 +2279,127 @@ def softplus_inverse(x): dtype=np.float32, ) return rbf + + +def prepare_values_for_test_tensornet_compare_radial_symmetry_features( + d_ij, min_distance, max_distance, number_of_radial_basis_functions, trainable, seed +): + from torchmdnet.models.utils import ExpNormalSmearing + + torch.manual_seed(seed) + rsf_tn = ExpNormalSmearing( + cutoff_lower=min_distance, + cutoff_upper=max_distance, + num_rbf=number_of_radial_basis_functions, + trainable=trainable, + ) + tn_r = rsf_tn(d_ij) + torch.save(tn_r, "modelforge/tests/data/tensornet_radial_symmetry_features.pt") + + return tn_r + + +def prepare_values_for_test_tensornet_input(mf_input, seed): + from torchmdnet.models.utils import OptimizedDistance + + pos, batch = (mf_input.positions, mf_input.atomic_subsystem_indices) + + torch.manual_seed(seed) + distance_module = OptimizedDistance( + cutoff_lower=0.0, + cutoff_upper=5.0, + max_num_pairs=153, + return_vecs=True, + loop=False, + check_errors=False, + resize_to_fit=False, # not self.static_shapes + box=None, + long_edge_index=False, + ) + + edge_index, edge_weight, edge_vec = distance_module(pos, batch, None) + torch.save( + (edge_index, edge_weight, edge_vec), "modelforge/tests/data/tensornet_input.pt" + ) + + return edge_index, edge_weight, edge_vec + + +def prepare_values_for_test_tensornet_representation( + nnp_input, + hidden_channels, + number_of_radial_basis_functions, + activation_function, + min_distance, + max_distance, + trainable_rbf, + max_atomic_number, + seed, +): + from torchmdnet.models.tensornet import TensorEmbedding + from torchmdnet.models.utils import ExpNormalSmearing + + torch.manual_seed(seed) + # TensorNet embedding modules setup + tensor_embedding = TensorEmbedding( + hidden_channels, + number_of_radial_basis_functions, + activation_function, + min_distance, + max_distance, + trainable_rbf, + max_atomic_number, + ) + + distance_expansion = ExpNormalSmearing( + min_distance, max_distance, number_of_radial_basis_functions, trainable_rbf + ) + + # calculate embedding + edge_attr = distance_expansion(nnp_input.d_ij.squeeze(-1) * 10) # Note: in angstrom + + tn_X = tensor_embedding( + nnp_input.atomic_numbers, + nnp_input.pair_indices, + nnp_input.d_ij.squeeze(-1) * 10, # Note: in angstrom + nnp_input.r_ij / nnp_input.d_ij, # edge_vec_norm in angstrom + edge_attr, + ) + torch.save(tn_X, "modelforge/tests/data/tensornet_representation.pt") + + return tn_X + + +def prepare_values_for_test_tensornet_interaction( + X, + nnp_input, + radial_feature_vector, + atomic_charges, + hidden_channels, + number_of_radial_basis_functions, + activation_function, + min_distance, + max_distance, + seed, +): + from torchmdnet.models.tensornet import Interaction + + torch.manual_seed(seed) + tn_interaction = Interaction( + number_of_radial_basis_functions, + hidden_channels, + activation_function, + min_distance, + max_distance, + "O(3)", + ) + tn_X = tn_interaction( + X, + nnp_input.pair_indices, + nnp_input.d_ij.squeeze(-1) * 10, + radial_feature_vector.squeeze(1), + atomic_charges, + ) + torch.save(tn_X, "modelforge/tests/data/tensornet_interaction.pt") + + return tn_X diff --git a/modelforge/tests/test_aimnet2.py b/modelforge/tests/test_aimnet2.py new file mode 100644 index 00000000..ce6385f8 --- /dev/null +++ b/modelforge/tests/test_aimnet2.py @@ -0,0 +1,227 @@ +import pytest +import torch +from openff.units import unit + +from modelforge.tests.helper_functions import setup_potential_for_test + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_aimnet2_temp") + return fn + + +def test_initialize_model(prep_temp_dir): + """Test initialization of the Schnet model.""" + + # read default parameters + model = setup_potential_for_test( + "aimnet2", "training", local_cache_dir=str(prep_temp_dir) + ) + + assert model is not None, "Aimnet2 model should be initialized." + + +def test_radial_symmetry_function_regression(): + from modelforge.potential import SchnetRadialBasisFunction + + # define radial symmetry function bounds and subdivisions + num_bins = 10 + lower_bound = unit.Quantity(0.5, unit.angstrom) + upper_bound = unit.Quantity(5.0, unit.angstrom) + + radial_symmetry_function_module = SchnetRadialBasisFunction( + num_bins, + min_distance=lower_bound.to(unit.nanometer).m, + max_distance=upper_bound.to(unit.nanometer).m, + ) + + # example interatomic distances in Angstroms, peaks should appear every + # other index and fall off when outside the upper_bound in outputs given + # lower_bound, upper_bound, and number of bins (every 0.5 in bin center + # value) + d_ij = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]]) + + # regression outputs + regression_outputs = torch.tensor( + [ + [ + 6.0653e-01, + 1.0000e00, + 6.0653e-01, + 1.3534e-01, + 1.1109e-02, + 3.3546e-04, + 3.7266e-06, + 1.5230e-08, + 2.2897e-11, + 1.2664e-14, + ], + [ + 1.1109e-02, + 1.3534e-01, + 6.0653e-01, + 1.0000e00, + 6.0653e-01, + 1.3534e-01, + 1.1109e-02, + 3.3546e-04, + 3.7266e-06, + 1.5230e-08, + ], + [ + 3.7266e-06, + 3.3546e-04, + 1.1109e-02, + 1.3534e-01, + 6.0653e-01, + 1.0000e00, + 6.0653e-01, + 1.3534e-01, + 1.1109e-02, + 3.3546e-04, + ], + [ + 2.2897e-11, + 1.5230e-08, + 3.7266e-06, + 3.3546e-04, + 1.1109e-02, + 1.3534e-01, + 6.0653e-01, + 1.0000e00, + 6.0653e-01, + 1.3534e-01, + ], + [ + 2.5767e-18, + 1.2664e-14, + 2.2897e-11, + 1.5230e-08, + 3.7266e-06, + 3.3546e-04, + 1.1109e-02, + 1.3534e-01, + 6.0653e-01, + 1.0000e00, + ], + [ + 5.3110e-27, + 1.9287e-22, + 2.5767e-18, + 1.2664e-14, + 2.2897e-11, + 1.5230e-08, + 3.7266e-06, + 3.3546e-04, + 1.1109e-02, + 1.3534e-01, + ], + [ + 2.0050e-37, + 5.3801e-32, + 5.3110e-27, + 1.9287e-22, + 2.5767e-18, + 1.2664e-14, + 2.2897e-11, + 1.5230e-08, + 3.7266e-06, + 3.3546e-04, + ], + ] + ) + + # module call expects units in nanometers, divide by 10 to correct scale + modelforge_aimnet2_outputs = radial_symmetry_function_module(d_ij / 10.0) + + assert torch.allclose(modelforge_aimnet2_outputs, regression_outputs, atol=1e-4) + + +def test_forward(single_batch_with_batchsize, prep_temp_dir): + """Test initialization of the AIMNet2 model.""" + # read default parameters + aimnet = setup_potential_for_test("aimnet2", "training", potential_seed=42) + + assert aimnet is not None, "Aimnet model should be initialized." + batch = single_batch_with_batchsize(64, "QM9", str(prep_temp_dir)) + + y_hat = aimnet(batch.nnp_input) + + assert y_hat is not None, "Aimnet model should be able to make predictions." + + ref_per_system_energy = torch.tensor( + [ + [0.2630], + [-0.5150], + [-0.2999], + [-0.0297], + [-0.4382], + [-0.1805], + [0.5974], + [0.1769], + [0.0842], + [-0.2955], + [0.1295], + [-0.4067], + [0.4135], + [0.3202], + [0.2481], + [0.6696], + [0.0380], + [0.0834], + [-0.2613], + [-0.8373], + [0.2033], + [0.1554], + [0.0624], + [-0.3643], + [-0.7861], + [-0.0398], + [-0.4675], + [-0.1000], + [0.3265], + [0.2546], + [-0.1597], + [-0.9611], + [0.0653], + [-0.4411], + [0.2587], + [-0.1082], + [0.0461], + [0.0407], + [0.6725], + [0.3874], + [0.3393], + [0.1747], + [0.4048], + [0.1001], + [0.1496], + [0.2432], + [0.3578], + [0.2792], + [-0.3365], + [-0.3329], + [-0.8465], + [0.0463], + [-0.4385], + [0.1224], + [-0.0442], + [0.1029], + [-0.4559], + [-1.1701], + [-0.2714], + [0.0318], + [-0.8579], + [-0.3836], + [0.2487], + [-0.2728], + ], + ) + + assert torch.allclose(y_hat["per_system_energy"], ref_per_system_energy, atol=1e-3) + + +@pytest.mark.xfail(raises=NotImplementedError) +def test_against_original_implementation(): + raise NotImplementedError diff --git a/modelforge/tests/test_ani.py b/modelforge/tests/test_ani.py index 8e5b5314..53712127 100644 --- a/modelforge/tests/test_ani.py +++ b/modelforge/tests/test_ani.py @@ -1,4 +1,15 @@ import pytest +from modelforge.tests.helper_functions import setup_potential_for_test +from importlib import resources +from modelforge.tests import data + +file_path = resources.files(data) / f"torchani_parameters.state" + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_ani_temp") + return fn def setup_methane(): @@ -24,13 +35,13 @@ def setup_methane(): [0, 0, 0, 0, 0], dtype=torch.int32, device=device ) - from modelforge.dataset.dataset import NNPInput + from modelforge.utils.prop import NNPInput nnp_input = NNPInput( atomic_numbers=torch.tensor([6, 1, 1, 1, 1], device=device), positions=coordinates.squeeze(0) / 10, atomic_subsystem_indices=atomic_subsystem_indices, - total_charge=torch.tensor([0.0]), + per_system_total_charge=torch.tensor([0.0]), ) return species, coordinates, device, nnp_input @@ -61,6 +72,16 @@ def setup_two_methanes(): requires_grad=True, device=device, ) + # Specify the translation vector + translation_vector = torch.tensor([1.0, 1.0, 1.0], device=device) + # Translate the second "molecule" without in-place modification + translated_coordinates = ( + coordinates.clone() + ) # Clone the tensor to avoid in-place modification + translated_coordinates[1] = translated_coordinates[1] + translation_vector + + print(translated_coordinates) + # In periodic table, C = 6 and H = 1 mf_species = torch.tensor([6, 1, 1, 1, 1, 6, 1, 1, 1, 1], device=device) ani_species = torch.tensor([[1, 0, 0, 0, 0], [1, 0, 0, 0, 0]], device=device) @@ -69,256 +90,213 @@ def setup_two_methanes(): ) atomic_numbers = mf_species - from modelforge.dataset.dataset import NNPInput + from modelforge.utils.prop import NNPInput nnp_input = NNPInput( atomic_numbers=atomic_numbers, - positions=torch.cat((coordinates[0], coordinates[1]), dim=0) / 10, + positions=torch.cat( + (translated_coordinates[0], translated_coordinates[1]), dim=0 + ) + / 10, atomic_subsystem_indices=atomic_subsystem_indices, - total_charge=torch.tensor([0.0, 0.0]), + per_system_total_charge=torch.tensor([0.0, 0.0]), ) - return ani_species, coordinates, device, nnp_input + return ani_species, translated_coordinates, device, nnp_input @pytest.mark.xfail -def test_forward_and_backward_using_torchani(): - # Test torchani ANI implementation - # Test forward pass and backpropagation through network - +def test_ani(): import torch import torchani - species, coordinates, device, _ = setup_two_methanes() - model = torchani.models.ANI2x(periodic_table_index=False).to(device) + # NOTE: in the following the input data is scaled to provide both + # torchani and modelforge ani the same input but in different units + # NOTE: output unit is Hartree - energy = model((species, coordinates)).energies - derivative = torch.autograd.grad(energy.sum(), coordinates)[0] - per_atom_force = -derivative - - -def test_forward_and_backward(): - # Test modelforge ANI implementation - # Test forward pass and backpropagation through network - from modelforge.potential.ani import ANI2x - from modelforge.tests.test_models import load_configs - import torch + # load reference implementation - # read default parameters - config = load_configs("ani2x", "qm9") - - _, _, _, mf_input = setup_two_methanes() - device = torch.device("cpu") - - # initialize model - model = ANI2x( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ).to(device=device) - energy = model(mf_input) - derivative = torch.autograd.grad( - energy["per_molecule_energy"].sum(), mf_input.positions - )[0] - per_atom_force = -derivative + # get input + species, coordinates, device, mf_input = setup_two_methanes() + # get single model + model = torchani.models.ANI2x(periodic_table_index=False, model_index=0) + # calculate energy for methane + energy = model((species, coordinates)).energies + # get per atom energy + w, torchani_atomic_energies = model.atomic_energies((species, coordinates)) -def test_representation(): - # Compare the reference radial symmetry function - # against the the implemented radial symmetry function - import torch - from modelforge.potential.utils import AniRadialBasisFunction, CosineCutoff - from openff.units import unit - from .precalculated_values import ( - provide_reference_values_for_test_ani_test_compare_rsf, + # compare to reference energy + assert torch.allclose( + torchani_atomic_energies, + torch.tensor( + [ + [-38.0841, -0.5797, -0.5898, -0.6034, -0.6027], + [-38.0841, -0.5797, -0.5898, -0.6034, -0.6027], + ] + ), + rtol=1e-4, ) - # use d_ij in angstrom - d_ij = torch.tensor([[3.5201], [2.6756], [2.1641], [3.0990], [4.5180]]) - radial_cutoff = 5.0 # radial_cutoff - radial_start = 0.8 - radial_dist_divisions = 8 - - # NOTE: we pass in Angstrom to ANI and in nanometer to mf - rsf = AniRadialBasisFunction( - number_of_radial_basis_functions=radial_dist_divisions, - max_distance=radial_cutoff * unit.angstrom, - min_distance=radial_start * unit.angstrom, + # calculate reference ase (substract per atom energy without ase from per + # atom energy with ase) + # NOTE: this is in Hartree + reference_ase = torch.tensor( + [ + [ + -38.08933878049795, + 0.5978583943827134, + 0.5978583943827134, + 0.5978583943827134, + 0.5978583943827134, + ], + [ + -38.08933878049795, + 0.5978583943827134, + 0.5978583943827134, + 0.5978583943827134, + 0.5978583943827134, + ], + ], ) - calculated_rsf = rsf(d_ij / 10) # torch.Size([5,1, 8]) # NOTE: nanometer - cutoff_module = CosineCutoff(radial_cutoff * unit.angstrom) - rcut_ij = cutoff_module(d_ij / 10) # torch.Size([5]) # NOTE: nanometer - reference_rsf = provide_reference_values_for_test_ani_test_compare_rsf() - calculated_rsf = calculated_rsf * rcut_ij - assert torch.allclose(calculated_rsf, reference_rsf, rtol=1e-4) - - -def test_representation_with_diagonal_batching(): - import torch - from modelforge.potential.utils import AniRadialBasisFunction, CosineCutoff - from openff.units import unit - from modelforge.potential.models import Pairlist - from .precalculated_values import ( - provide_reference_values_for_test_ani_test_compute_rsf_with_diagonal_batching, + # ------------------------------------------ # + # setup modelforge potential + potential = setup_potential_for_test( + use="training", + potential_seed=42, + potential_name="ani2x", + jit=False, + local_cache_dir=str(prep_temp_dir), ) - - # ------------ general setup -------------# - ani_species, ani_coordinates, _, mf_input = setup_two_methanes() - pairlist = Pairlist(only_unique_pairs=True) - pairs = pairlist( - mf_input.positions, - mf_input.atomic_subsystem_indices, + # load the original ani2x parameter set + potential.load_state_dict(torch.load(file_path)) + # compare to original ani2x dataset + atomic_energies = potential(mf_input)["per_atom_energy"] + modelforge_atomic_energies = ( + atomic_energies.flatten() + reference_ase.squeeze(0).flatten() ) - d_ij = pairs.d_ij - # ANI constants - radial_cutoff = 5.1 # radial_cutoff - radial_start = 0.8 - radial_dist_divisions = 16 - # ------------ Modelforge calculation ----------# - device = torch.device("cpu") + print(atomic_energies.flatten()) + print(torchani_atomic_energies.flatten() - reference_ase.flatten()) + + print(modelforge_atomic_energies) + print(torchani_atomic_energies.flatten()) - radial_symmetry_function = AniRadialBasisFunction( - radial_dist_divisions, - radial_cutoff * unit.angstrom, - radial_start * unit.angstrom, - ).to(device=device) + assert torch.allclose( + modelforge_atomic_energies, + torchani_atomic_energies.flatten(), + rtol=1e-3, + ) - cutoff_module = CosineCutoff(radial_cutoff * unit.angstrom).to(device=device) - rcut_ij = cutoff_module(d_ij) - calculated_rbf_output = radial_symmetry_function(d_ij) - calculated_rbf_output = calculated_rbf_output * rcut_ij +def test_ani_against_torchani_reference(): + import torch - # test that both ANI and MF obtain the same radial symmetry outpu - reference_rbf_output, ani_d_ij = ( - provide_reference_values_for_test_ani_test_compute_rsf_with_diagonal_batching() + # get input + species, coordinates, device, mf_input = setup_two_methanes() + + # ------------------------------------------ # + # setup modelforge potential + potential = setup_potential_for_test( + use="training", + potential_seed=42, + potential_name="ani2x", + jit=False, + local_cache_dir=str(prep_temp_dir), ) - assert torch.allclose(calculated_rbf_output, reference_rbf_output, atol=1e-4) + # load the original ani2x parameter set + potential.load_state_dict(torch.load(file_path)) + # compare to original ani2x dataset + atomic_energies = potential(mf_input)["per_atom_energy"] + assert torch.allclose( - ani_d_ij, d_ij.squeeze(1) * 10, atol=1e-4 - ) # NOTE: unit mismatch + atomic_energies, + torch.tensor( + [ + [0.0052], + [0.0181], + [0.0080], + [-0.0055], + [-0.0048], + [0.0052], + [0.0181], + [0.0080], + [-0.0055], + [-0.0048], + ] + ), + rtol=1e-2, + ) # that's the atomic energies for the two methane molecules obtained with torchani - assert calculated_rbf_output.shape == torch.Size([20, radial_dist_divisions]) + a = 7 -def test_compare_angular_symmetry_features(): - # Compare the calculated angular symmetry function output - # against the reference angular symmetry functino output +@pytest.mark.parametrize("mode", ["inference", "training"]) +def test_forward_and_backward(mode): + # Test modelforge ANI implementation + # Test forward and backward pass import torch - from modelforge.potential.utils import AngularSymmetryFunction, triple_by_molecule - from openff.units import unit - from modelforge.potential.models import Pairlist - - device = torch.device("cpu") - # set up relevant system properties - species, r, _, _ = setup_methane() - pairlist = Pairlist(only_unique_pairs=True).to(device=device) - pairs = pairlist(r[0], torch.tensor([0, 0, 0, 0, 0], device=device)) - d_ij = pairs.d_ij.squeeze(1) - r_ij = pairs.r_ij.squeeze(1) - - # reformat for input - species = species.flatten() - atom_index12 = pairs.pair_indices - # ANI constants - # for angular features - angular_cutoff = Rca = 3.5 # angular_cutoff - angular_start = 0.8 - angular_dist_divisions = 8 - - # get index in right order - even_closer_indices = (d_ij <= Rca).nonzero().flatten() - atom_index12 = atom_index12.index_select(1, even_closer_indices) - r_ij = r_ij.index_select(0, even_closer_indices) - central_atom_index, pair_index12, sign12 = triple_by_molecule(atom_index12) - vec12 = r_ij.index_select(0, pair_index12.view(-1)).view( - 2, -1, 3 - ) * sign12.unsqueeze(-1) - - # now use formated indices and inputs to calculate the - # angular terms, both with the modelforge AngularSymmetryFunction - # and with its implementation in torchani - - # ref value - from .precalculated_values import ( - provide_input_for_test_test_compare_angular_symmetry_features, + model = setup_potential_for_test( + use=mode, + potential_seed=42, + potential_name="ani2x", + simulation_environment="PyTorch", + use_training_mode_neighborlist=True, + jit=False, ) - reference_angular_feature_vector = ( - provide_input_for_test_test_compare_angular_symmetry_features() - ) + _, _, _, mf_input = setup_two_methanes() - # set up modelforge angular features - asf = AngularSymmetryFunction( - angular_cutoff * unit.angstrom, - angular_start * unit.angstrom, - angular_dist_divisions, - angle_sections=4, - ) - # NOTE: ANI works with Angstrom, modelforge with nanometer - # NOTE: ANI operates on a [nr_of_molecules, nr_of_atoms, 3] tensor - calculated_angular_feature_vector = asf(vec12 / 10) - # make sure that the output is the same - assert ( - calculated_angular_feature_vector.size() - == reference_angular_feature_vector.size() - ) + energy = model(mf_input) + derivative = torch.autograd.grad( + energy["per_system_energy"].sum(), mf_input.positions + )[0] + per_atom_force = -derivative - # NOTE: the order of the angular_feature_vector is not guaranteed - # as the triple_by_molecule function used to prepare the inputs does not use stable sorting. - # When stable sorting is used, the output is identical across platforms, but will not be - # used here as it is slower and the order of the output is not important in practrice. - # As such, to check for equivalence in a way that is not order dependent, we can just consider the sum. + # same input, same output assert torch.isclose( - torch.sum(calculated_angular_feature_vector), - torch.sum(reference_angular_feature_vector), - atol=1e-4, + energy["per_system_energy"][0], energy["per_system_energy"][1], rtol=1e-4 ) + assert torch.allclose(per_atom_force[0:5], per_atom_force[5:10], rtol=1e-4) -def test_compare_aev(): - """ - Compare the atomic enviornment vector generated by the reference implementation (torchani) and modelforge for the same input - """ +def test_representation(): + # Compare the reference radial symmetry function output against the the + # implemented radial symmetry function import torch - from .precalculated_values import provide_input_for_test_ani_test_compare_aev - - # methane input - species, coordinates, device, mf_input = setup_methane() - - # generate modelforge ani representation - from modelforge.potential import ANI2x - - # read default parameters - from modelforge.tests.test_models import load_configs - - # read default parameters - config = load_configs("ani2x", "qm9") - - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) - - mf_model = ANI2x( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], + from modelforge.potential import ( + AniRadialBasisFunction, + CosineAttenuationFunction, ) - # perform input checks - mf_model.input_preparation._input_checks(mf_input) - # prepare the input for the forward pass - pairlist_output = mf_model.input_preparation.prepare_inputs(mf_input) - nnp_input = mf_model.core_module._model_specific_input_preparation( - mf_input, pairlist_output + from openff.units import unit + from .precalculated_values import ( + provide_reference_values_for_test_ani_test_compare_rsf, ) - representation_module_output = mf_model.core_module.ani_representation_module( - nnp_input + + # set up relevant variables + d_ij = unit.Quantity( + torch.tensor([[3.5201], [2.6756], [2.1641], [3.0990], [4.5180]]), unit.angstrom ) + max_distance = unit.Quantity(5.0, unit.angstrom) + min_distance = unit.Quantity(0.8, unit.angstrom) + radial_dist_divisions = 8 - reference_aev = provide_input_for_test_ani_test_compare_aev() - # test for equivalence - assert torch.Size([5, 1008]) == representation_module_output.aevs.shape - # compare a selected subsection - assert torch.allclose( - reference_aev, representation_module_output.aevs[::2, :50:5], atol=1e-4 + # pass parameters to the radial symmetry function + rsf = AniRadialBasisFunction( + number_of_radial_basis_functions=radial_dist_divisions, + max_distance=max_distance.to(unit.nanometer).m, + min_distance=min_distance.to(unit.nanometer).m, ) + + calculated_rsf = rsf(d_ij.to(unit.nanometer).m) # torch.Size([5,1, 8]) + cutoff_module = CosineAttenuationFunction(max_distance.to(unit.nanometer).m) + + rcut_ij = cutoff_module(d_ij.to(unit.nanometer).m) # torch.Size([5]) + calculated_rsf = calculated_rsf * rcut_ij + + # get the precalculated output obtained from torchani for the same d_ij and + # cutoff values + reference_rsf = provide_reference_values_for_test_ani_test_compare_rsf() + assert torch.allclose(calculated_rsf, reference_rsf, rtol=1e-4) diff --git a/modelforge/tests/test_curation.py b/modelforge/tests/test_curation.py index d915bfcf..2cf2f051 100644 --- a/modelforge/tests/test_curation.py +++ b/modelforge/tests/test_curation.py @@ -127,6 +127,15 @@ def test_dict_to_hdf5(prep_temp_dir): id_key="name", ) + # test.hdf5 was generated in test_dict_to_hdf5 + files = list_files(str(prep_temp_dir), ".hdf5") + + # check to see if test.hdf5 is in the files + assert "test.hdf5" in files + + with pytest.raises(Exception): + list_files("/path/that/should/not/exist/", ".hdf5") + def test_series_dict_to_hdf5(prep_temp_dir): # generate an hdf5 file from simple test data @@ -215,17 +224,6 @@ def test_series_dict_to_hdf5(prep_temp_dir): assert records[i][key] == test_data[i][key] -def test_list_files(prep_temp_dir): - # test.hdf5 was generated in test_dict_to_hdf5 - files = list_files(str(prep_temp_dir), ".hdf5") - - # check to see if test.hdf5 is in the files - assert "test.hdf5" in files - - with pytest.raises(Exception): - list_files("/path/that/should/not/exist/", ".hdf5") - - def test_str_to_float(prep_temp_dir): val = str_to_float("1*^6") assert val == 1e6 @@ -347,9 +345,7 @@ def test_qm9_curation_parse_xyz(prep_temp_dir): assert data_dict_temp["energy_of_homo"] == [[-0.3877]] * unit.hartree assert data_dict_temp["energy_of_lumo"] == [[0.1171]] * unit.hartree assert data_dict_temp["lumo-homo_gap"] == [[0.5048]] * unit.hartree - assert ( - data_dict_temp["electronic_spatial_extent"] == [[35.3641]] * unit.angstrom**2 - ) + assert data_dict_temp["electronic_spatial_extent"] == [[35.3641]] * unit.angstrom**2 assert ( data_dict_temp["zero_point_vibrational_energy"] == [[0.044749]] * unit.hartree ) diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index d5cdb0b7..edb82c80 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -11,6 +11,12 @@ from modelforge.utils.prop import PropertyNames +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_dataset_temp") + return fn + + def test_dataset_imported(): """Sample test, will always pass so long as import statement worked.""" @@ -37,11 +43,12 @@ def test_dataset_basic_operations(): ), "atomic_subsystem_counts": atomic_subsystem_counts, "n_confs": n_confs, - "charges": torch.randint(-1, 2, torch.Size([total_confs])).numpy(), } property_names = PropertyNames( - "atomic_numbers", "geometry", "internal_energy_at_0K", "charges" + atomic_numbers="atomic_numbers", + positions="geometry", + E="internal_energy_at_0K", ) dataset = TorchDataset(input_data, property_names) assert len(dataset) == total_confs @@ -76,11 +83,15 @@ def test_dataset_basic_operations(): for conf_idx in range(len(dataset)): conf_data = dataset[conf_idx] - assert np.array_equal(conf_data.nnp_input.positions, geom_true[conf_idx]) + pos1 = geom_true[conf_idx] + pos2 = conf_data.nnp_input.positions + assert np.array_equal(pos2, pos1) assert np.array_equal( conf_data.nnp_input.atomic_numbers, atomic_numbers_true[conf_idx] ) - assert np.array_equal(conf_data.metadata.E, energy_true[conf_idx]) + assert np.array_equal( + conf_data.metadata.per_system_energy, energy_true[conf_idx] + ) for rec_idx in range(dataset.record_len()): assert np.array_equal( @@ -89,8 +100,16 @@ def test_dataset_basic_operations(): @pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) -def test_different_properties_of_interest(dataset_name, dataset_factory, prep_temp_dir): +def test_get_properties(dataset_name, single_batch_with_batchsize, prep_temp_dir): + batch = single_batch_with_batchsize( + batch_size=16, dataset_name=dataset_name, local_cache_dir=str(prep_temp_dir) + ) + a = 7 + + +@pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) +def test_different_properties_of_interest(dataset_name, dataset_factory, prep_temp_dir): local_cache_dir = str(prep_temp_dir) + "/data_test" data = _ImplementedDatasets.get_dataset_class( @@ -101,7 +120,7 @@ def test_different_properties_of_interest(dataset_name, dataset_factory, prep_te "geometry", "atomic_numbers", "internal_energy_at_0K", - "charges", + "dipole_moment", ] # spot check the processing of the yaml file assert data.gz_data_file["length"] == 1697917 @@ -130,6 +149,7 @@ def test_different_properties_of_interest(dataset_name, dataset_factory, prep_te "atomic_numbers", "wb97x_dz.energy", "wb97x_dz.forces", + "dipole_moment", ] data.properties_of_interest = [ @@ -152,20 +172,21 @@ def test_different_properties_of_interest(dataset_name, dataset_factory, prep_te "atomic_numbers", "dft_total_energy", "dft_total_force", - "mbis_charges", + "total_charge", + "scf_dipole", ] data.properties_of_interest = [ "dft_total_energy", "geometry", "atomic_numbers", - "mbis_charges", + "total_charge", ] assert data.properties_of_interest == [ "dft_total_energy", "geometry", "atomic_numbers", - "mbis_charges", + "total_charge", ] elif dataset_name == "PhAlkEthOH": assert data.properties_of_interest == [ @@ -174,6 +195,7 @@ def test_different_properties_of_interest(dataset_name, dataset_factory, prep_te "dft_total_energy", "dft_total_force", "total_charge", + "dipole_moment", ] data.properties_of_interest = [ @@ -194,12 +216,8 @@ def test_different_properties_of_interest(dataset_name, dataset_factory, prep_te raw_data_item = dataset[0] assert isinstance(raw_data_item, BatchData) assert len(raw_data_item.__dataclass_fields__) == 2 - assert ( - len(raw_data_item.nnp_input.__dataclass_fields__) == 5 - ) # 8 properties are returned - assert ( - len(raw_data_item.metadata.__dataclass_fields__) == 5 - ) # 8 properties are returned + assert len(raw_data_item.nnp_input.__slots__) == 8 # 8 properties are returned + assert len(raw_data_item.metadata.__slots__) == 6 # 6 properties are returned @pytest.mark.parametrize("dataset_name", ["QM9"]) @@ -216,7 +234,6 @@ def test_file_existence_after_initialization( ) with contextlib.suppress(FileNotFoundError): - os.remove(f"{local_cache_dir}/{data.gz_data_file['name']}") os.remove(f"{local_cache_dir}/{data.hdf5_data_file['name']}") os.remove(f"{local_cache_dir}/{data.processed_data_file['name']}") @@ -234,7 +251,7 @@ def test_file_existence_after_initialization( def test_caching(prep_temp_dir): import contextlib - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) from modelforge.dataset.qm9 import QM9Dataset data = QM9Dataset(version_select="nc_1000_v0", local_cache_dir=local_cache_dir) @@ -306,10 +323,11 @@ def test_caching(prep_temp_dir): def test_metadata_validation(prep_temp_dir): - """When we generate an .npz file, we also write out metadata in a .json file which is used - to validate if we can use .npz file, or we need to regenerate it.""" + """When we generate an .npz file, we also write out metadata in a .json file + which is used to validate if we can use .npz file, or we need to + regenerate it.""" - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) from modelforge.dataset.qm9 import QM9Dataset @@ -331,7 +349,12 @@ def test_metadata_validation(prep_temp_dir): assert data._metadata_validation("qm9_test.json", local_cache_dir) == False metadata = { - "data_keys": ["atomic_numbers", "internal_energy_at_0K", "geometry", "charges"], + "data_keys": [ + "atomic_numbers", + "internal_energy_at_0K", + "geometry", + "dipole_moment", + ], "hdf5_checksum": "305a0602860f181fafa75f7c7e3e6de4", "hdf5_gz_checkusm": "dc8ada0d808d02c699daf2000aff1fe9", "date_generated": "2024-04-11 14:05:14.297305", @@ -339,9 +362,14 @@ def test_metadata_validation(prep_temp_dir): import json + # create local_cache_dir if not already present + import os + + os.makedirs(local_cache_dir, exist_ok=True) + with open( f"{local_cache_dir}/qm9_test.json", - "w", + "w+", ) as f: json.dump(metadata, f) @@ -435,18 +463,19 @@ def test_data_item_format_of_datamodule( """Test the format of individual data items in the dataset.""" from typing import Dict - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) dm = datamodule_factory( dataset_name=dataset_name, batch_size=512, + local_cache_dir=local_cache_dir, ) raw_data_item = dm.torch_dataset[0] assert isinstance(raw_data_item, BatchData) assert isinstance(raw_data_item.nnp_input.atomic_numbers, torch.Tensor) assert isinstance(raw_data_item.nnp_input.positions, torch.Tensor) - assert isinstance(raw_data_item.metadata.E, torch.Tensor) + assert isinstance(raw_data_item.metadata.per_system_energy, torch.Tensor) assert ( raw_data_item.nnp_input.atomic_numbers.shape[0] @@ -457,25 +486,60 @@ def test_data_item_format_of_datamodule( from modelforge.potential import _Implemented_NNPs -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_dataset_neighborlist(model_name, single_batch_with_batchsize_64): +@pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) +def test_removal_of_self_energy(dataset_name, datamodule_factory, prep_temp_dir): + # test the self energy calculation on the QM9 dataset + from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy + + # prepare reference value + dm = datamodule_factory( + dataset_name=dataset_name, + batch_size=512, + splitting_strategy=FirstComeFirstServeSplittingStrategy(), + version_select="nc_1000_v0", + remove_self_energies=False, + regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), + ) + + first_entry_with_ase = dm.train_dataset[0].metadata.per_system_energy + + # prepare reference value + dm = datamodule_factory( + dataset_name=dataset_name, + batch_size=512, + splitting_strategy=FirstComeFirstServeSplittingStrategy(), + version_select="nc_1000_v0", + remove_self_energies=True, + local_cache_dir=str(prep_temp_dir), + ) + + atomic_numbers = dm.train_dataset[0].nnp_input.atomic_numbers + first_entry_without_ase = dm.train_dataset[0].metadata.per_system_energy + assert not torch.allclose(first_entry_with_ase, first_entry_without_ase) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_dataset_neighborlist( + potential_name, single_batch_with_batchsize, prep_temp_dir +): """Test the neighborlist.""" - nnp_input = single_batch_with_batchsize_64.nnp_input + batch = single_batch_with_batchsize(64, "QM9", str(prep_temp_dir)) + nnp_input = batch.nnp_input # test that the neighborlist is correctly generated - from modelforge.tests.test_models import load_configs - - # read default parameters - config = load_configs(f"{model_name}", "qm9") + from modelforge.tests.helper_functions import setup_potential_for_test - # Extract parameters - from modelforge.potential.models import NeuralNetworkPotentialFactory - # initialize model - model = NeuralNetworkPotentialFactory.generate_model( + model = setup_potential_for_test( use="inference", + potential_seed=42, + potential_name="ani2x", simulation_environment="PyTorch", - model_parameter=config["potential"], + use_training_mode_neighborlist=True, + local_cache_dir=str(prep_temp_dir), ) model(nnp_input) @@ -564,10 +628,12 @@ def test_dataset_neighborlist(model_name, single_batch_with_batchsize_64): @pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) -def test_dataset_generation(dataset_name, datamodule_factory): +def test_dataset_generation(dataset_name, datamodule_factory, prep_temp_dir): """Test the splitting of the dataset.""" - dataset = datamodule_factory(dataset_name=dataset_name) + dataset = datamodule_factory( + dataset_name=dataset_name, local_cache_dir=str(prep_temp_dir) + ) train_dataloader = dataset.train_dataloader() val_dataloader = dataset.val_dataloader() test_dataloader = dataset.test_dataloader() @@ -577,9 +643,9 @@ def test_dataset_generation(dataset_name, datamodule_factory): # this isn't set when dataset is in 'fit' mode pass - # the dataloader automatically splits and batches the dataset - # for the training set it batches the 800 training datapoints (of 1000 total) in 13 batches - # all with 64 points until the last which has 32 + # the dataloader automatically splits and batches the dataset for the + # training set it batches the 800 training datapoints (of 1000 total) in 13 + # batches all with 64 points until the last which has 32 assert len(train_dataloader) == 13 # nr of batches batch_data = [v_ for v_ in train_dataloader] @@ -614,17 +680,20 @@ def test_dataset_generation(dataset_name, datamodule_factory): ) @pytest.mark.parametrize("dataset_name", ["QM9"]) def test_dataset_splitting( - splitting_strategy, dataset_name, datamodule_factory, get_dataset_container_fix + splitting_strategy, + dataset_name, + datamodule_factory, + get_dataset_container_fix, + prep_temp_dir, ): """Test random_split on the the dataset.""" - from modelforge.dataset import DataModule - dm = datamodule_factory( dataset_name=dataset_name, batch_size=512, splitting_strategy=splitting_strategy(), version_select="nc_1000_v0", remove_self_energies=False, + local_cache_dir=str(prep_temp_dir), ) train_dataset, val_dataset, test_dataset = ( @@ -633,7 +702,7 @@ def test_dataset_splitting( dm.test_dataset, ) - energy = train_dataset[0].metadata.E.item() + energy = train_dataset[0].metadata.per_system_energy.item() dataset_to_test = get_dataset_container_fix(dataset_name) if splitting_strategy == RandomSplittingStrategy: assert np.isclose(energy, dataset_to_test.expected_E_random_split) @@ -646,6 +715,7 @@ def test_dataset_splitting( splitting_strategy=splitting_strategy(split=[0.6, 0.3, 0.1]), version_select="nc_1000_v0", remove_self_energies=False, + local_cache_dir=str(prep_temp_dir), ) train_dataset2, val_dataset2, test_dataset2 = ( @@ -683,7 +753,7 @@ def test_dataset_downloader(dataset_name, dataset_factory, prep_temp_dir): """ Test the DatasetDownloader functionality. """ - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) dataset = dataset_factory( dataset_name=dataset_name, local_cache_dir=local_cache_dir @@ -695,7 +765,7 @@ def test_dataset_downloader(dataset_name, dataset_factory, prep_temp_dir): @pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) -def test_numpy_dataset_assignment(dataset_name): +def test_numpy_dataset_assignment(dataset_name, prep_temp_dir): """ Test if the numpy_dataset attribute is correctly assigned after processing or loading. """ @@ -703,7 +773,7 @@ def test_numpy_dataset_assignment(dataset_name): factory = DatasetFactory() data = _ImplementedDatasets.get_dataset_class(dataset_name)( - version_select="nc_1000_v0" + version_select="nc_1000_v0", local_cache_dir=str(prep_temp_dir) ) factory._load_or_process_data(data) @@ -711,8 +781,9 @@ def test_numpy_dataset_assignment(dataset_name): assert isinstance(data.numpy_data, np.lib.npyio.NpzFile) -def test_energy_postprocessing(): - # setup test dataset +def test_energy_postprocessing(prep_temp_dir): + # test that the mean and stddev of the dataset + # are correct from modelforge.dataset.dataset import DataModule # test the self energy calculation on the QM9 dataset @@ -727,17 +798,18 @@ def test_energy_postprocessing(): splitting_strategy=FirstComeFirstServeSplittingStrategy(), remove_self_energies=True, regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), ) dm.prepare_data() dm.setup() batch = next(iter(dm.val_dataloader())) - unnormalized_E = batch.metadata.E.numpy().flatten() + unnormalized_E = batch.metadata.per_system_energy.numpy().flatten() import numpy as np # check that normalized energies are correct assert torch.allclose( - batch.metadata.E.squeeze(1), + batch.metadata.per_system_energy.squeeze(1), torch.tensor( [ [ @@ -792,8 +864,7 @@ def test_energy_postprocessing(): @pytest.mark.parametrize("dataset_name", ["QM9"]) -def test_function_of_self_energy(dataset_name, datamodule_factory): - +def test_function_of_self_energy(dataset_name, datamodule_factory, prep_temp_dir): # test the self energy calculation on the QM9 dataset from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy @@ -805,9 +876,10 @@ def test_function_of_self_energy(dataset_name, datamodule_factory): version_select="nc_1000_v0", remove_self_energies=False, regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), ) - methane_energy_reference = float(dm.train_dataset[0].metadata.E) + methane_energy_reference = float(dm.train_dataset[0].metadata.per_system_energy) assert np.isclose(methane_energy_reference, -106277.4161) # Scenario 1: dataset contains self energies @@ -816,6 +888,7 @@ def test_function_of_self_energy(dataset_name, datamodule_factory): batch_size=512, splitting_strategy=FirstComeFirstServeSplittingStrategy(), regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), ) # it is saved in the dataset statistics @@ -847,6 +920,7 @@ def test_function_of_self_energy(dataset_name, datamodule_factory): remove_self_energies=True, version_select="nc_1000_v0", regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), ) # it is saved in the dataset statistics @@ -891,6 +965,7 @@ def test_function_of_self_energy(dataset_name, datamodule_factory): remove_self_energies=True, version_select="nc_1000_v0", regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), ) # it is saved in the dataset statistics import toml @@ -909,12 +984,13 @@ def test_function_of_self_energy(dataset_name, datamodule_factory): remove_self_energies=True, version_select="nc_1000_v0", regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), ) # Extract the first molecule (methane) # double check that it is methane methane_atomic_indices = dm.train_dataset[0].nnp_input.atomic_numbers # extract energy - methane_energy_offset = dm.train_dataset[0].metadata.E + methane_energy_offset = dm.train_dataset[0].metadata.per_system_energy if regression is False: # checking that the offset energy is actually correct for methane assert torch.isclose( @@ -938,3 +1014,236 @@ def test_function_of_self_energy(dataset_name, datamodule_factory): ) # compare this to the energy without postprocessing assert np.isclose(methane_energy_reference, methane_energy_offset + methane_ase) + + +def test_shifting_center_of_mass_to_origin(prep_temp_dir): + local_cache_dir = str(prep_temp_dir) + + from modelforge.dataset.dataset import initialize_datamodule + from openff.units.elements import MASSES + + import torch + + # first check a molecule not centered at the origin + dm = initialize_datamodule( + "QM9", + version_select="latest_test", + shift_center_of_mass_to_origin=False, + local_cache_dir=local_cache_dir, + ) + start_idx = dm.torch_dataset.single_atom_start_idxs_by_conf[0] + end_idx = dm.torch_dataset.single_atom_end_idxs_by_conf[0] + + from openff.units.elements import MASSES + + pos = dm.torch_dataset.properties_of_interest["positions"][start_idx:end_idx] + + atomic_masses = torch.Tensor( + [ + MASSES[atomic_number].m + for atomic_number in dm.torch_dataset.properties_of_interest[ + "atomic_numbers" + ][start_idx:end_idx].tolist() + ] + ) + molecule_mass = torch.sum(atomic_masses) + + # I'm using einsum, so let us check it manually + + x = 0 + y = 0 + z = 0 + for i in range(0, pos.shape[0]): + x += atomic_masses[i] * pos[i][0] + y += atomic_masses[i] * pos[i][1] + z += atomic_masses[i] * pos[i][2] + + x = x / molecule_mass + y = y / molecule_mass + z = z / molecule_mass + + com = torch.Tensor([x, y, z]) + + assert torch.allclose(com, torch.Tensor([-0.0013, 0.1086, 0.0008]), atol=1e-4) + + # make sure that we do shift to the origin; we can do the whole dataset + + dm = initialize_datamodule( + "QM9", + version_select="latest_test", + shift_center_of_mass_to_origin=True, + local_cache_dir=local_cache_dir, + ) + + for conf_id in range(0, len(dm.torch_dataset)): + start_idx = dm.torch_dataset.single_atom_start_idxs_by_conf[conf_id] + end_idx = dm.torch_dataset.single_atom_end_idxs_by_conf[conf_id] + + # grab the positions that should be shifted + pos = dm.torch_dataset.properties_of_interest["positions"][start_idx:end_idx] + + atomic_masses = torch.Tensor( + [ + MASSES[atomic_number].m + for atomic_number in dm.torch_dataset.properties_of_interest[ + "atomic_numbers" + ][start_idx:end_idx].tolist() + ] + ) + molecule_mass = torch.sum(atomic_masses) + + x = 0 + y = 0 + z = 0 + for i in range(0, pos.shape[0]): + x += atomic_masses[i] * pos[i][0] + y += atomic_masses[i] * pos[i][1] + z += atomic_masses[i] * pos[i][2] + + x = x / molecule_mass + y = y / molecule_mass + z = z / molecule_mass + + com = torch.Tensor([x, y, z]) + assert torch.allclose(com, torch.Tensor([0.0, 0.0, 0.0]), atol=1e-4) + + +def test_shifting_center_of_mass_to_origin(prep_temp_dir): + local_cache_dir = str(prep_temp_dir) + + from modelforge.dataset.dataset import initialize_datamodule + from openff.units.elements import MASSES + + import torch + + # first check a molecule not centered at the origin + dm = initialize_datamodule( + "QM9", + version_select="nc_1000_v0", + shift_center_of_mass_to_origin=False, + local_cache_dir=local_cache_dir, + ) + start_idx = dm.torch_dataset.single_atom_start_idxs_by_conf[0] + end_idx = dm.torch_dataset.single_atom_end_idxs_by_conf[0] + + from openff.units.elements import MASSES + + pos = dm.torch_dataset.properties_of_interest["positions"][start_idx:end_idx] + + atomic_masses = torch.Tensor( + [ + MASSES[atomic_number].m + for atomic_number in dm.torch_dataset.properties_of_interest[ + "atomic_numbers" + ][start_idx:end_idx].tolist() + ] + ) + molecule_mass = torch.sum(atomic_masses) + + # I'm using einsum, so let us check it manually + + x = 0 + y = 0 + z = 0 + for i in range(0, pos.shape[0]): + x += atomic_masses[i] * pos[i][0] + y += atomic_masses[i] * pos[i][1] + z += atomic_masses[i] * pos[i][2] + + x = x / molecule_mass + y = y / molecule_mass + z = z / molecule_mass + + com = torch.Tensor([x, y, z]) + + assert torch.allclose(com, torch.Tensor([-0.0013, 0.1086, 0.0008]), atol=1e-4) + + # make sure that we do shift to the origin; we can do the whole dataset + + dm = initialize_datamodule( + "PhAlkEthOH", + version_select="latest_test", + shift_center_of_mass_to_origin=True, + local_cache_dir=local_cache_dir, + ) + dm_no_shift = initialize_datamodule( + "PhAlkEthOH", + version_select="latest_test", + shift_center_of_mass_to_origin=False, + local_cache_dir=local_cache_dir, + ) + for conf_id in range(0, len(dm.torch_dataset)): + start_idx_mol = dm.torch_dataset.series_atom_start_idxs_by_conf[conf_id] + end_idx_mol = dm.torch_dataset.series_atom_start_idxs_by_conf[conf_id + 1] + + start_idx = dm.torch_dataset.single_atom_start_idxs_by_conf[conf_id] + end_idx = dm.torch_dataset.single_atom_end_idxs_by_conf[conf_id] + + # grab the positions that should be shifted + pos = dm.torch_dataset.properties_of_interest["positions"][ + start_idx_mol:end_idx_mol + ] + + pos_original = dm_no_shift.torch_dataset.properties_of_interest["positions"][ + start_idx_mol:end_idx_mol + ] + atomic_masses = torch.Tensor( + [ + MASSES[atomic_number].m + for atomic_number in dm.torch_dataset.properties_of_interest[ + "atomic_numbers" + ][start_idx:end_idx].tolist() + ] + ) + molecule_mass = torch.sum(atomic_masses) + + x = 0 + y = 0 + z = 0 + + x_ns = 0 + y_ns = 0 + z_ns = 0 + for i in range(0, pos.shape[0]): + x += atomic_masses[i] * pos[i][0] + y += atomic_masses[i] * pos[i][1] + z += atomic_masses[i] * pos[i][2] + + x_ns += atomic_masses[i] * pos_original[i][0] + y_ns += atomic_masses[i] * pos_original[i][1] + z_ns += atomic_masses[i] * pos_original[i][2] + + x = x / molecule_mass + y = y / molecule_mass + z = z / molecule_mass + + x_ns = x_ns / molecule_mass + y_ns = y_ns / molecule_mass + z_ns = z_ns / molecule_mass + + com = torch.Tensor([x, y, z]) + com_ns = torch.Tensor([x_ns, y_ns, z_ns]) + + pos_ns = pos.clone() + for i in range(0, pos_ns.shape[0]): + pos_ns[i] = pos_original[i] - com_ns + + assert torch.allclose(com, torch.Tensor([0.0, 0.0, 0.0]), atol=1e-4) + + # I don't expect to be exactly the same because use of einsum in the main code + # but still should be pretty close + assert torch.allclose(pos, pos_original - com_ns, atol=1e-3) + assert torch.allclose(pos, pos_ns, atol=1e-3) + + from modelforge.potential.neighbors import NeighborListForTraining + + nnp_input = dm.torch_dataset[conf_id].nnp_input + nnp_input_ns = dm_no_shift.torch_dataset[conf_id].nnp_input + + nlist = NeighborListForTraining(cutoff=0.5) + + pairs = nlist(nnp_input) + pairs_ns = nlist(nnp_input_ns) + + assert torch.allclose(pairs.r_ij, pairs_ns.r_ij, atol=1e-4) + assert torch.allclose(pairs.d_ij, pairs_ns.d_ij, atol=1e-4) diff --git a/modelforge/tests/test_modelforge.py b/modelforge/tests/test_modelforge.py index fbb03dfd..5e1c1929 100644 --- a/modelforge/tests/test_modelforge.py +++ b/modelforge/tests/test_modelforge.py @@ -14,5 +14,3 @@ def test_modelforge_imported(): """Sample test, will always pass so long as import statement worked.""" print("importing ", modelforge.__name__) assert "modelforge" in sys.modules - - diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py deleted file mode 100644 index 1133697e..00000000 --- a/modelforge/tests/test_models.py +++ /dev/null @@ -1,1088 +0,0 @@ -import pytest - -from modelforge.potential import _Implemented_NNPs -from modelforge.dataset import _ImplementedDatasets -from modelforge.potential import NeuralNetworkPotentialFactory - - -def load_configs(model_name: str, dataset_name: str): - from modelforge.tests.data import ( - potential_defaults, - training_defaults, - dataset_defaults, - ) - from importlib import resources - from modelforge.train.training import return_toml_config - - potential_path = resources.files(potential_defaults) / f"{model_name.lower()}.toml" - dataset_path = resources.files(dataset_defaults) / f"{dataset_name}.toml" - training_path = resources.files(training_defaults) / "default.toml" - - return return_toml_config( - potential_path=potential_path, - dataset_path=dataset_path, - training_path=training_path, - ) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_JAX_wrapping(model_name, single_batch_with_batchsize_64): - from modelforge.potential.models import ( - NeuralNetworkPotentialFactory, - ) - - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - # inference model - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="JAX", - model_parameter=config["potential"], - ) - - assert "JAX" in str(type(model)) - nnp_input = single_batch_with_batchsize_64.nnp_input.as_jax_namedtuple() - out = model(nnp_input)["per_molecule_energy"] - import jax - - grad_fn = jax.grad(lambda pos: out.sum()) # Create a gradient function - forces = -grad_fn( - nnp_input.positions - ) # Evaluate gradient function and apply negative sign - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -@pytest.mark.parametrize("simulation_environment", ["JAX", "PyTorch"]) -def test_model_factory(model_name, simulation_environment): - from modelforge.potential.models import ( - NeuralNetworkPotentialFactory, - ) - from modelforge.train.training import TrainingAdapter - - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - # inference model - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment=simulation_environment, - model_parameter=config["potential"], - ) - assert ( - model_name.upper() in str(type(model)).upper() - or "JAX" in str(type(model)).upper() - ) - - # Extract parameters - training_parameter = config["training"].get("training_parameter", {}) - # training model - model = NeuralNetworkPotentialFactory.generate_model( - use="training", - simulation_environment=simulation_environment, - model_parameter=config["potential"], - training_parameter=training_parameter, - ) - assert type(model) == TrainingAdapter - - -def test_energy_scaling_and_offset(): - # setup test dataset - from modelforge.potential.ani import ANI2x - from modelforge.dataset.dataset import DataModule - - import torch - - # prepare reference value - # get methane input - # test the self energy calculation on the QM9 dataset - from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy - - # prepare reference value - dataset = DataModule( - name="QM9", - batch_size=1, - version_select="nc_1000_v0", - splitting_strategy=FirstComeFirstServeSplittingStrategy(), - remove_self_energies=True, - regression_ase=False, - ) - dataset.prepare_data() - dataset.setup() - # get methane input - methane = next(iter(dataset.train_dataloader(shuffle=False))).nnp_input - # load dataset statistic - import toml - - dataset_statistic = toml.load(dataset.dataset_statistic_filename) - # -------------------------------# - # initialize model without any postprocessing - # -------------------------------# - config = load_configs("ani2x", "qm9") - - torch.manual_seed(42) - model = ANI2x( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ) - output_no_postprocessing = model(methane) - # -------------------------------# - # Scale output - - torch.manual_seed(42) - model = ANI2x( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - dataset_statistic=dataset_statistic, - ) - scaled_output = model(methane) - - # make sure that the scaled output equals the unscaled output - from openff.units import unit - - mean = unit.Quantity( - dataset_statistic["training_dataset_statistics"]["per_atom_energy_mean"] - ).m - stddev = unit.Quantity( - dataset_statistic["training_dataset_statistics"]["per_atom_energy_stddev"] - ).m - - compare_to = output_no_postprocessing["per_atom_energy"] * stddev + mean - assert torch.allclose(scaled_output["per_atom_energy"], compare_to) - - # -------------------------------# - # Calculate atomic self energies - - # modify postprocessing parameters - config["potential"]["postprocessing_parameter"][ - "general_postprocessing_operation" - ] = {"calculate_molecular_self_energy": True} - - model = ANI2x( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - dataset_statistic=dataset_statistic, - ) - - output_with_molecular_self_energies = model(methane) - - # make sure that the raw prediction is the same - assert torch.isclose( - output_with_molecular_self_energies["per_molecule_self_energy"], - torch.tensor([-104620.5859]), - ) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_state_dict_saving_and_loading(model_name): - from modelforge.potential import NeuralNetworkPotentialFactory - import torch - - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - # Extract parameters - training_parameter = config["training"].get("training_parameter", {}) - - model1 = NeuralNetworkPotentialFactory.generate_model( - use="training", - simulation_environment="PyTorch", - model_parameter=config["potential"], - training_parameter=training_parameter, - ) - torch.save(model1.state_dict(), "model.pth") - - model2 = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="PyTorch", - model_parameter=config["potential"], - ) - model2.load_state_dict(torch.load("model.pth")) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_dataset_statistic(model_name): - # Test that the scaling parmaeters are propagated from the dataset to the - # training model and then via the state_dict to the inference model - - from modelforge.dataset.dataset import DataModule - from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy - - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - # Extract parameters - training_parameter = config["training"].get("training_parameter", {}) - - # test the self energy calculation on the QM9 dataset - dataset = DataModule( - name="QM9", - batch_size=64, - version_select="nc_1000_v0", - splitting_strategy=FirstComeFirstServeSplittingStrategy(), - remove_self_energies=True, - regression_ase=False, - regenerate_dataset_statistic=True, - ) - dataset.prepare_data() - dataset.setup() - - import toml - from openff.units import unit - - # load dataset stastics from file - dataset_statistic = toml.load(dataset.dataset_statistic_filename) - - # extract value to compare against - toml_E_i_mean = unit.Quantity( - dataset_statistic["training_dataset_statistics"]["per_atom_energy_mean"] - ).m - - # set up training model - training_adapter = NeuralNetworkPotentialFactory.generate_model( - use="training", - simulation_environment="PyTorch", - model_parameter=config["potential"], - training_parameter=training_parameter, - dataset_statistic=dataset_statistic, - ) - import torch - import numpy as np - - print(training_adapter.model.postprocessing.dataset_statistic) - # check that the per_atom_energy_mean is the same than in the dataset statistics - assert np.isclose( - toml_E_i_mean, - unit.Quantity( - training_adapter.model.postprocessing.dataset_statistic[ - "training_dataset_statistics" - ]["per_atom_energy_mean"] - ).m, - ) - - torch.save(training_adapter.state_dict(), "model.pth") - - # NOTE: we are passing dataset statistics explicit to the constructor - # this is not saved with the state_dict - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="PyTorch", - model_parameter=config["potential"], - dataset_statistic=dataset_statistic, - ) - model.load_state_dict(torch.load("model.pth")) - - a = 7 - - assert np.isclose( - toml_E_i_mean, - unit.Quantity( - model.postprocessing.dataset_statistic["training_dataset_statistics"][ - "per_atom_energy_mean" - ] - ).m, - ) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_energy_between_simulation_environments( - model_name, single_batch_with_batchsize_64 -): - # compare that the energy is the same for the JAX and PyTorch Model - import numpy as np - import torch - - nnp_input = single_batch_with_batchsize_64.nnp_input - # test the forward pass through each of the models - # cast input and model to torch.float64 - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - # Setup loss - from modelforge.train.training import return_toml_config - - torch.manual_seed(42) - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="PyTorch", - model_parameter=config["potential"], - ) - - output_torch = model(nnp_input)["per_molecule_energy"] - - torch.manual_seed(42) - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="JAX", - model_parameter=config["potential"], - ) - nnp_input = nnp_input.as_jax_namedtuple() - output_jax = model(nnp_input)["per_molecule_energy"] - - # test tat we get an energie per molecule - assert np.isclose(output_torch.sum().detach().numpy(), output_jax.sum()) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -@pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) -def test_forward_pass_with_all_datasets(model_name, dataset_name, datamodule_factory): - """Test forward pass with all datasets.""" - import torch - - if dataset_name.lower().startswith("spice"): - print("using subset") - dataset = datamodule_factory( - dataset_name=dataset_name, version_select="nc_1000_v0_HCNOFClS" - ) - else: - dataset = datamodule_factory(dataset_name=dataset_name) - - train_dataloader = dataset.train_dataloader() - batch = next(iter(train_dataloader)) - - config = load_configs(f"{model_name.lower()}", dataset_name.lower()) - - from modelforge.potential.models import NeuralNetworkPotentialFactory - import toml - - dataset_statistic = toml.load(dataset.dataset_statistic_filename) - - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="PyTorch", - model_parameter=config["potential"], - dataset_statistic=dataset_statistic, - ) - output = model(batch.nnp_input) - - # test that the output has the following keys and follwing dim - assert "per_molecule_energy" in output - assert "per_atom_energy" in output - - assert output["per_molecule_energy"].shape[0] == 64 - assert output["per_atom_energy"].shape == batch.nnp_input.atomic_numbers.shape - - pair_list = batch.nnp_input.pair_list - # pairlist is in ascending order in row 0 - assert torch.all(pair_list[0, 1:] >= pair_list[0, :-1]) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -@pytest.mark.parametrize("simulation_environment", ["JAX", "PyTorch"]) -def test_forward_pass( - model_name, simulation_environment, single_batch_with_batchsize_64 -): - # this test sends a single batch from different datasets through the model - import torch - - nnp_input = single_batch_with_batchsize_64.nnp_input - - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] - - # test the forward pass through each of the models - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment=simulation_environment, - model_parameter=config["potential"], - ) - if "JAX" in str(type(model)): - nnp_input = nnp_input.as_jax_namedtuple() - - output = model(nnp_input) - - # test that we get an energie per molecule - assert len(output["per_molecule_energy"]) == nr_of_mols - - # the batch consists of methane (CH4) and amamonium (NH3) - # which have chemically equivalent hydrogens at the minimum geometry. - # This has to be reflected in the atomic energies E_i, which - # have to be equal for all hydrogens - if "JAX" not in str(type(model)): - # assert that the following tensor has equal values for dim=0 index 1 to 4 and 6 to 8 - assert torch.allclose( - output["per_atom_energy"][1:4], output["per_atom_energy"][1], atol=1e-5 - ) - assert torch.allclose( - output["per_atom_energy"][6:8], output["per_atom_energy"][6], atol=1e-5 - ) - - # make sure that the total energy is \sum E_i - assert torch.allclose( - output["per_molecule_energy"][0], - output["per_atom_energy"][0:5].sum(dim=0), - atol=1e-5, - ) - assert torch.allclose( - output["per_molecule_energy"][1], - output["per_atom_energy"][5:9].sum(dim=0), - atol=1e-5, - ) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_calculate_energies_and_forces(model_name, single_batch_with_batchsize_64): - """ - Test the calculation of energies and forces for a molecule. - """ - import torch - - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - # Extract parameters - training_parameter = config["training"].get("training_parameter", {}) - - # get batch - nnp_input = single_batch_with_batchsize_64.nnp_input - - # test the pass through each of the models - torch.manual_seed(42) - model_inference = NeuralNetworkPotentialFactory.generate_model( - use="inference", - model_parameter=config["potential"], - ) - E_inference = model_inference(nnp_input)["per_molecule_energy"] - - # backpropagation - F_inference = -torch.autograd.grad( - E_inference.sum(), nnp_input.positions, create_graph=True, retain_graph=True - )[0] - - # make sure that dimension are as expected - nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] - nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] - - assert E_inference.shape == torch.Size([nr_of_mols]) - assert F_inference.shape == (nr_of_atoms_per_batch, 3) # only one molecule - - torch.manual_seed(42) - model_training = NeuralNetworkPotentialFactory.generate_model( - use="training", - model_parameter=config["potential"], - training_parameter=training_parameter, - ) - - E_training = model_training.model.forward(nnp_input)["per_molecule_energy"] - F_training = -torch.autograd.grad( - E_training.sum(), nnp_input.positions, create_graph=True, retain_graph=True - )[0] - - # make sure that both agree on E and F - assert torch.allclose(E_inference, E_training, atol=1e-4) - assert torch.allclose(F_inference, F_training, atol=1e-4) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_calculate_energies_and_forces_with_jax( - model_name, single_batch_with_batchsize_64 -): - """ - Test the calculation of energies and forces for a molecule. - """ - import torch - - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - nnp_input = single_batch_with_batchsize_64.nnp_input - # test the backward pass through each of the models - nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] - nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] - - # The inference_model fixture now returns a function that expects an environment - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - model_parameter=config["potential"], - simulation_environment="JAX", - ) - - nnp_input = nnp_input.as_jax_namedtuple() - - result = model(nnp_input)["per_molecule_energy"] - - from modelforge.utils.io import import_ - - jax = import_("jax") - - grad_fn = jax.grad(lambda pos: result.sum()) # Create a gradient function - forces = -grad_fn( - nnp_input.positions - ) # Evaluate gradient function and apply negative sign - assert result.shape == torch.Size([nr_of_mols]) # only one molecule - assert forces.shape == (nr_of_atoms_per_batch, 3) # only one molecule - - -def test_pairlist_logic(): - import torch - - # dummy data for illustration - positions = torch.tensor( - [ - [0.4933, 0.4460, 0.5762], - [0.2340, 0.2053, 0.5025], - [0.6566, 0.1263, 0.8792], - [0.1656, 0.0338, 0.6708], - [0.5696, 0.4790, 0.9622], - [0.3499, 0.4241, 0.8818], - [0.8400, 0.9389, 0.1888], - [0.4983, 0.0793, 0.8639], - [0.6605, 0.7567, 0.1938], - [0.7725, 0.9758, 0.7063], - ] - ) - molecule_indices = torch.tensor( - [0, 0, 0, 1, 1, 2, 2, 2, 3, 3] - ) # molecule index for each atom - - # generate index grid - n = len(molecule_indices) - i_indices, j_indices = torch.triu_indices(n, n, 1) - - # filter pairs to only keep those belonging to the same molecule - same_molecule_mask = molecule_indices[i_indices] == molecule_indices[j_indices] - - # Apply mask to get final pair indices - i_final_pairs = i_indices[same_molecule_mask] - j_final_pairs = j_indices[same_molecule_mask] - - # Concatenate to form final (2, n_pairs) tensor - final_pair_indices = torch.stack((i_final_pairs, j_final_pairs)) - - assert torch.allclose( - final_pair_indices, - torch.tensor([[0, 0, 1, 3, 5, 5, 6, 8], [1, 2, 2, 4, 6, 7, 7, 9]]), - ) - - # Create pair_coordinates tensor - pair_coordinates = positions[final_pair_indices.T] - pair_coordinates = pair_coordinates.view(-1, 2, 3) - - # Calculate distances - distances = (pair_coordinates[:, 0, :] - pair_coordinates[:, 1, :]).norm( - p=2, dim=-1 - ) - # Calculate distances - distances = (pair_coordinates[:, 0, :] - pair_coordinates[:, 1, :]).norm( - p=2, dim=-1 - ) - - # Define a cutoff - cutoff = 1.0 - - # Find pairs within the cutoff - in_cutoff = (distances <= cutoff).nonzero(as_tuple=False).squeeze() - - # Get the atom indices within the cutoff - atom_pairs_withing_cutoff = final_pair_indices[:, in_cutoff] - assert torch.allclose( - atom_pairs_withing_cutoff, - torch.tensor([[0, 0, 1, 3, 5, 5, 8], [1, 2, 2, 4, 6, 7, 9]]), - ) - - -def test_pairlist(): - from modelforge.potential.models import Pairlist, Neighborlist - import torch - - atomic_subsystem_indices = torch.tensor([0, 0, 0, 1, 1, 1]) - positions = torch.tensor( - [ - [0.0, 0.0, 0.0], - [1.0, 1.0, 1.0], - [2.0, 2.0, 2.0], - [3.0, 3.0, 3.0], - [4.0, 4.0, 4.0], - [5.0, 5.0, 5.0], - ] - ) - from openff.units import unit - - cutoff = 5.0 * unit.nanometer # no relevant cutoff - pairlist = Neighborlist(cutoff, only_unique_pairs=True) - r = pairlist(positions, atomic_subsystem_indices) - pair_indices = r.pair_indices - - # pairlist describes the pairs of interacting atoms within a batch - # that means for the pairlist provided below: - # pair1: pairlist[0][0] and pairlist[1][0], i.e. (0,1) - # pair2: pairlist[0][1] and pairlist[1][1], i.e. (0,2) - # pair3: pairlist[0][2] and pairlist[1][2], i.e. (1,2) - - assert torch.allclose( - pair_indices, torch.tensor([[0, 0, 1, 3, 3, 4], [1, 2, 2, 4, 5, 5]]) - ) - # NOTE: pairs are defined on axis=1 and not axis=0 - assert torch.allclose( - r.r_ij, - torch.tensor( - [ - [1.0, 1.0, 1.0], # pair1, [1.0, 1.0, 1.0] - [0.0, 0.0, 0.0] - [2.0, 2.0, 2.0], # pair2, [2.0, 2.0, 2.0] - [0.0, 0.0, 0.0] - [1.0, 1.0, 1.0], # pair3, [3.0, 3.0, 3.0] - [0.0, 0.0, 0.0] - [1.0, 1.0, 1.0], - [2.0, 2.0, 2.0], - [1.0, 1.0, 1.0], - ] - ), - ) - - # test with cutoff - cutoff = 2.0 * unit.nanometer - pairlist = Neighborlist(cutoff, only_unique_pairs=True) - r = pairlist(positions, atomic_subsystem_indices) - pair_indices = r.pair_indices - - assert torch.equal(pair_indices, torch.tensor([[0, 1, 3, 4], [1, 2, 4, 5]])) - # pairs that are excluded through cutoff: (0,2) and (3,5) - assert torch.equal( - r.r_ij, - torch.tensor( - [ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ] - ), - ) - - assert torch.allclose( - r.d_ij, torch.tensor([1.7321, 1.7321, 1.7321, 1.7321]), atol=1e-3 - ) - - # test with complete pairlist - cutoff = 2.0 * unit.nanometer - pairlist = Neighborlist(cutoff, only_unique_pairs=False) - r = pairlist(positions, atomic_subsystem_indices) - pair_indices = r.pair_indices - - print(pair_indices, flush=True) - assert torch.equal( - pair_indices, torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]]) - ) - - # make sure that Pairlist and Neighborlist behave the same for large cutoffs - cutoff = 10.0 * unit.nanometer - only_unique_pairs = False - neighborlist = Neighborlist(cutoff, only_unique_pairs=only_unique_pairs) - pairlist = Pairlist(only_unique_pairs=only_unique_pairs) - r = pairlist(positions, atomic_subsystem_indices) - pair_indices = r.pair_indices - r = neighborlist(positions, atomic_subsystem_indices) - neighbor_indices = r.pair_indices - - assert torch.equal(pair_indices, neighbor_indices) - - # make sure that they are the same also for non-redundant pairs - cutoff = 10.0 * unit.nanometer - only_unique_pairs = True - neighborlist = Neighborlist(cutoff, only_unique_pairs=only_unique_pairs) - pairlist = Pairlist(only_unique_pairs=only_unique_pairs) - r = pairlist(positions, atomic_subsystem_indices) - pair_indices = r.pair_indices - r = neighborlist(positions, atomic_subsystem_indices) - neighbor_indices = r.pair_indices - - assert torch.equal(pair_indices, neighbor_indices) - - # this should fail - cutoff = 2.0 * unit.nanometer - only_unique_pairs = True - neighborlist = Neighborlist(cutoff, only_unique_pairs=only_unique_pairs) - pairlist = Pairlist(only_unique_pairs=only_unique_pairs) - r = pairlist(positions, atomic_subsystem_indices) - pair_indices = r.pair_indices - r = neighborlist(positions, atomic_subsystem_indices) - neighbor_indices = r.pair_indices - - assert not pair_indices.shape == neighbor_indices.shape - - -def test_pairlist_precomputation(): - from modelforge.potential.models import Pairlist - import torch - import numpy as np - - atomic_subsystem_indices = torch.tensor([0, 0, 0]) - - pairlist = Pairlist() - - pairs, nr_pairs = pairlist.construct_initial_pairlist_using_numpy( - atomic_subsystem_indices.to("cpu") - ) - - assert pairs.shape == (2, 6) - assert nr_pairs[0] == 6 - - # 3 molecules, 3 atoms each - atomic_subsystem_indices = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]) - pairs, nr_pairs = pairlist.construct_initial_pairlist_using_numpy( - atomic_subsystem_indices.to("cpu") - ) - - assert pairs.shape == (2, 18) - assert np.all(nr_pairs == [6, 6, 6]) - - # 3 molecules, 3,4, and 5 atoms each - atomic_subsystem_indices = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]) - pairs, nr_pairs = pairlist.construct_initial_pairlist_using_numpy( - atomic_subsystem_indices.to("cpu") - ) - - assert pairs.shape == (2, 38) - assert np.all(nr_pairs == [6, 12, 20]) - - -def test_pairlist_on_dataset(): - # Set up a dataset - from modelforge.dataset.dataset import DataModule - from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy - - # prepare reference value - dataset = DataModule( - name="QM9", - batch_size=1, - version_select="nc_1000_v0", - splitting_strategy=FirstComeFirstServeSplittingStrategy(), - remove_self_energies=True, - regression_ase=False, - ) - dataset.prepare_data() - dataset.setup() - # -------------------------------# - # -------------------------------# - # get methane input - batch = next(iter(dataset.train_dataloader(shuffle=False))).nnp_input - import torch - - # make sure that the pairlist of methane is correct (single molecule) - assert torch.equal( - batch.pair_list, - torch.tensor( - [ - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], - [1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3], - ] - ), - ) - - # test that the pairlist of 2 molecules is correct (which can then be expected also to be true for N molecules) - dataset = DataModule( - name="QM9", - batch_size=2, - version_select="nc_1000_v0", - splitting_strategy=FirstComeFirstServeSplittingStrategy(), - remove_self_energies=True, - regression_ase=False, - ) - dataset.prepare_data() - dataset.setup() - # -------------------------------# - # -------------------------------# - # get methane input - batch = next(iter(dataset.train_dataloader(shuffle=False))).nnp_input - - assert torch.equal( - batch.pair_list, - torch.tensor( - [ - [ - 0, - 0, - 0, - 0, - 1, - 1, - 1, - 1, - 2, - 2, - 2, - 2, - 3, - 3, - 3, - 3, - 4, - 4, - 4, - 4, - 5, - 5, - 5, - 6, - 6, - 6, - 7, - 7, - 7, - 8, - 8, - 8, - ], - [ - 1, - 2, - 3, - 4, - 0, - 2, - 3, - 4, - 0, - 1, - 3, - 4, - 0, - 1, - 2, - 4, - 0, - 1, - 2, - 3, - 6, - 7, - 8, - 5, - 7, - 8, - 5, - 6, - 8, - 5, - 6, - 7, - ], - ] - ), - ) - - # check that the pairlist maximum value for i is the number of atoms in the batch - assert ( - int(batch.pair_list[0][-1].item()) + 1 == 8 + 1 == len(batch.atomic_numbers) - ) # +1 because of 0-based indexing - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -def test_casting(model_name, single_batch_with_batchsize_64): - # test dtype casting - import torch - - batch = single_batch_with_batchsize_64 - batch_ = batch.to(dtype=torch.float64) - assert batch_.nnp_input.positions.dtype == torch.float64 - batch_ = batch_.to(dtype=torch.float32) - assert batch_.nnp_input.positions.dtype == torch.float32 - - nnp_input = batch.nnp_input.to(dtype=torch.float64) - assert nnp_input.positions.dtype == torch.float64 - nnp_input = batch.nnp_input.to(dtype=torch.float32) - assert nnp_input.positions.dtype == torch.float32 - nnp_input = batch.metadata.to(dtype=torch.float64) - - # cast input and model to torch.float64 - # read default parameters - config = load_configs(f"{model_name.lower()}", "qm9") - - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="PyTorch", - model_parameter=config["potential"], - ) - model = model.to(dtype=torch.float64) - nnp_input = batch.nnp_input.to(dtype=torch.float64) - - model(nnp_input) - - # cast input and model to torch.float64 - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment="PyTorch", - model_parameter=config["potential"], - ) - model = model.to(dtype=torch.float32) - nnp_input = batch.nnp_input.to(dtype=torch.float32) - - model(nnp_input) - - -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -@pytest.mark.parametrize("simulation_environment", ["PyTorch"]) -def test_equivariant_energies_and_forces( - model_name, - simulation_environment, - single_batch_with_batchsize_64, - equivariance_utils, -): - """ - Test the calculation of energies and forces for a molecule. - NOTE: test will be adapted once we have a trained model. - """ - import torch - from dataclasses import replace - - # load default parameters - config = load_configs(f"{model_name}", "qm9") - - model = NeuralNetworkPotentialFactory.generate_model( - use="inference", - simulation_environment=simulation_environment, - model_parameter=config["potential"], - ) - - # define the symmetry operations - translation, rotation, reflection = equivariance_utils - # define the tolerance - atol = 1e-3 - nnp_input = single_batch_with_batchsize_64.nnp_input - - # initialize the models - model = model.to(dtype=torch.float64) - - # ------------------- # - # start the test - # reference values - nnp_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float64) - reference_result = model(nnp_input)["per_molecule_energy"].to(dtype=torch.float64) - reference_forces = -torch.autograd.grad( - reference_result.sum(), - nnp_input.positions, - )[0] - - # translation test - translation_nnp_input = replace(nnp_input) - translation_nnp_input.positions = translation(translation_nnp_input.positions) - translation_result = model(translation_nnp_input)["per_molecule_energy"] - assert torch.allclose( - translation_result, - reference_result, - atol=atol, - ) - - translation_forces = -torch.autograd.grad( - translation_result.sum(), - translation_nnp_input.positions, - )[0] - - for t, r in zip(translation_forces, reference_forces): - if not torch.allclose(t, r, atol=atol): - print(t, r) - - assert torch.allclose( - translation_forces, - reference_forces, - atol=atol, - ) - - # rotation test - rotation_input_data = replace(nnp_input) - rotation_input_data.positions = rotation(rotation_input_data.positions) - rotation_result = model(rotation_input_data)["per_molecule_energy"] - - for t, r in zip(rotation_result, reference_result): - if not torch.allclose(t, r, atol=atol): - print(t, r) - - assert torch.allclose( - rotation_result, - reference_result, - atol=atol, - ) - - rotation_forces = -torch.autograd.grad( - rotation_result.sum(), - rotation_input_data.positions, - create_graph=True, - retain_graph=True, - )[0] - - rotate_reference = rotation(reference_forces) - assert torch.allclose( - rotation_forces, - rotate_reference, - atol=atol, - ) - - # reflection test - reflection_input_data = replace(nnp_input) - reflection_input_data.positions = reflection(reflection_input_data.positions) - reflection_result = model(reflection_input_data)["per_molecule_energy"] - reflection_forces = -torch.autograd.grad( - reflection_result.sum(), - reflection_input_data.positions, - create_graph=True, - retain_graph=True, - )[0] - for t, r in zip(reflection_result, reference_result): - if not torch.allclose(t, r, atol=atol): - print(t, r) - - assert torch.allclose( - reflection_result, - reference_result, - atol=atol, - ) - - assert torch.allclose( - reflection_forces, - reflection(reference_forces), - atol=atol, - ) - - -def test_pairlist_calculate_r_ij_and_d_ij(): - # Define inputs - from modelforge.potential.models import Neighborlist - import torch - - positions = torch.tensor( - [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 4.0, 1.0]] - ) - atomic_subsystem_indices = torch.tensor([0, 0, 1, 1]) - from openff.units import unit - - cutoff = 3.0 * unit.nanometer - - # Create Pairlist instance - # --------------------------- # - # Only unique pairs - pairlist = Neighborlist(cutoff, only_unique_pairs=True) - pair_indices = pairlist.enumerate_all_pairs(atomic_subsystem_indices) - - # Calculate r_ij and d_ij - r_ij = pairlist.calculate_r_ij(pair_indices, positions) - d_ij = pairlist.calculate_d_ij(r_ij) - - # Check if the calculated r_ij and d_ij are correct - expected_r_ij = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 1.0]]) - expected_d_ij = torch.tensor([[2.0000], [2.2361]]) - - assert torch.allclose(r_ij, expected_r_ij, atol=1e-3) - assert torch.allclose(d_ij, expected_d_ij, atol=1e-3) - - normalized_r_ij = r_ij / d_ij - expected_normalized_r_ij = torch.tensor( - [[1.0000, 0.0000, 0.0000], [0.0000, 0.8944, 0.4472]] - ) - assert torch.allclose(expected_normalized_r_ij, normalized_r_ij, atol=1e-3) - - # --------------------------- # - # ALL pairs - pairlist = Neighborlist(cutoff, only_unique_pairs=False) - pair_indices = pairlist.enumerate_all_pairs(atomic_subsystem_indices) - - # Calculate r_ij and d_ij - r_ij = pairlist.calculate_r_ij(pair_indices, positions) - d_ij = pairlist.calculate_d_ij(r_ij) - - # Check if the calculated r_ij and d_ij are correct - expected_r_ij = torch.tensor( - [[2.0, 0.0, 0.0], [-2.0, 0.0, 0.0], [0.0, 2.0, 1.0], [0.0, -2.0, -1.0]] - ) - expected_d_ij = torch.tensor([[2.0000], [2.0000], [2.2361], [2.2361]]) - - assert torch.allclose(r_ij, expected_r_ij, atol=1e-3) - assert torch.allclose(d_ij, expected_d_ij, atol=1e-3) diff --git a/modelforge/tests/test_nn.py b/modelforge/tests/test_nn.py index c0399488..c7fd1de1 100644 --- a/modelforge/tests/test_nn.py +++ b/modelforge/tests/test_nn.py @@ -1,34 +1,71 @@ -def test_radial_symmetry_function(): +from .test_potentials import load_configs_into_pydantic_models +import pytest - from modelforge.potential.utils import SchnetRadialBasisFunction, CosineCutoff - import torch - from openff.units import unit - # set cutoff and radial symmetry function - cutoff = CosineCutoff(cutoff=unit.Quantity(5.0, unit.angstrom)) - rbf_expension = SchnetRadialBasisFunction( - number_of_radial_basis_functions=18, - max_distance=unit.Quantity(5.0, unit.angstrom), +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_nn_temp") + return fn + + +def test_embedding(single_batch_with_batchsize, prep_temp_dir): + # test the input featurization, including: + # - nuclear charge embedding + # - total charge mixing + + import torch # noqa: F401 + + batch = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) ) - # calculate expension and cutoff - d_ij = torch.tensor( - [[0.0], [0.1], [0.2], [0.3], [0.4], [0.5]] - ) # distances have the dimensions [nr_of_pairs, 1] (because displacement vectors have the dimensions [nr_of_pairs, 3]) + nnp_input = batch.nnp_input + model_name = "SchNet" + # read default parameters and extract featurization + config = load_configs_into_pydantic_models(f"{model_name.lower()}", "qm9") + featurization_config = config["potential"].core_parameter.featurization.model_dump() - f_ij_cutoff = cutoff(d_ij) - f_ij = rbf_expension(d_ij) - vs = f_ij * f_ij_cutoff + # featurize the atomic input (default is only atomic number embedding) + from modelforge.potential import FeaturizeInput - # make sure that this matches the output of SchNETRepresentation - from modelforge.potential.schnet import SchNETRepresentation + featurize_input_module = FeaturizeInput(featurization_config) - rep = SchNETRepresentation( - radial_cutoff=5 * unit.angstrom, - number_of_radial_basis_functions=18, + # mixing module should be the identity operation since only atomic number + # embedding is used + mixing_module = featurize_input_module.mixing + assert mixing_module.__module__ == "torch.nn.modules.linear" + mixing_module_name = str(mixing_module) + + # only atomic number embedded + assert "atomic_number" in featurize_input_module.registered_embedding_operations + assert len(featurize_input_module.registered_embedding_operations) == 1 + # no mixing + assert "Identity()" in mixing_module_name + + # add total charge to the input + featurization_config["properties_to_featurize"].append("per_system_total_charge") + featurize_input_module = FeaturizeInput(featurization_config) + + # only atomic number embedded + assert "atomic_number" in featurize_input_module.registered_embedding_operations + assert len(featurize_input_module.registered_embedding_operations) == 1 + # total charge is added to feature vector + assert ( + "per_system_total_charge" + in featurize_input_module.registered_appended_properties ) + assert len(featurize_input_module.registered_appended_properties) == 1 + + mixing_module = featurize_input_module.mixing + assert ( + mixing_module.__module__ == "modelforge.potential.utils" + ) # this is were Dense lives + mixing_module_name = str(mixing_module) - representation = rep(d_ij) - f_ij_cutoff = representation["f_ij"] * representation["f_cutoff"] + assert "Dense" in mixing_module_name - assert torch.allclose(vs, f_ij_cutoff) + # make a forward pass, embedd nuclear charges and add total charge (is expanded from per-molecule to per-atom property). Mix the properties then. + out = featurize_input_module(nnp_input) + assert out.shape == torch.Size( + [557, 32] + ) # nr_of_atoms, nr_of_per_atom_features (the total charge is mixed in) diff --git a/modelforge/tests/test_painn.py b/modelforge/tests/test_painn.py index 8b8c75db..ced91329 100644 --- a/modelforge/tests/test_painn.py +++ b/modelforge/tests/test_painn.py @@ -1,23 +1,48 @@ -import torch -from modelforge.potential.painn import PaiNN +import pytest +from modelforge.potential import NeuralNetworkPotentialFactory -def test_forward(single_batch_with_batchsize_64): - """Test initialization of the PaiNN neural network potential.""" - # read default parameters - from modelforge.tests.test_models import load_configs - # read default parameters - config = load_configs("painn", "qm9") +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_painn_temp") + return fn - painn = PaiNN( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], + +def setup_painn_model(potential_seed: int): + from modelforge.tests.test_potentials import load_configs_into_pydantic_models + + # read default parameters + config = load_configs_into_pydantic_models("painn", "qm9") + # override defaults to match reference implementation in spk + config[ + "potential" + ].core_parameter.featurization.atomic_number.maximum_atomic_number = 100 + config[ + "potential" + ].core_parameter.featurization.atomic_number.number_of_per_atom_features = 8 + config["potential"].core_parameter.number_of_radial_basis_functions = 5 + + trainer_painn = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config["potential"], + training_parameter=config["training"], + dataset_parameter=config["dataset"], + runtime_parameter=config["runtime"], + potential_seed=potential_seed, + ).lightning_module.potential + return trainer_painn + + +def test_forward(single_batch_with_batchsize, prep_temp_dir): + """Test initialization of the PaiNN neural network potential.""" + trainer_painn = setup_painn_model(42) + assert trainer_painn is not None, "PaiNN model should be initialized." + batch = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) ) - assert painn is not None, "PaiNN model should be initialized." - nnp_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float32) - energy = painn(nnp_input)["per_molecule_energy"] + nnp_input = batch.to_dtype(dtype=torch.float32).nnp_input + energy = trainer_painn(nnp_input)["per_system_energy"] nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] assert ( @@ -25,201 +50,53 @@ def test_forward(single_batch_with_batchsize_64): ) # Assuming energy is calculated per sample in the batch -def test_equivariance(single_batch_with_batchsize_64): - from modelforge.potential.painn import PaiNN - from dataclasses import replace - import torch - - from modelforge.tests.test_models import load_configs - - # read default parameters - config = load_configs("painn", "qm9") - - # define a rotation matrix in 3D that rotates by 90 degrees around the z-axis - # (clockwise when looking along the z-axis towards the origin) - rotation_matrix = torch.tensor( - [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64 - ) - - painn = PaiNN( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ).double() - - methane_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float64) - perturbed_methane_input = replace(methane_input) - perturbed_methane_input.positions = torch.matmul( - methane_input.positions, rotation_matrix - ) - - # prepare reference and perturbed inputs - pairlist_output = painn.input_preparation.prepare_inputs(methane_input) - reference_prepared_input = painn.core_module._model_specific_input_preparation( - methane_input, pairlist_output - ) - - reference_d_ij = reference_prepared_input.d_ij - reference_r_ij = reference_prepared_input.r_ij - reference_dir_ij = reference_r_ij / reference_d_ij - reference_f_ij = ( - painn.core_module.representation_module.radial_symmetry_function_module( - reference_d_ij - ) - ) - - pairlist_output = painn.input_preparation.prepare_inputs(perturbed_methane_input) - perturbed_prepared_input = painn.core_module._model_specific_input_preparation( - perturbed_methane_input, pairlist_output - ) - - perturbed_d_ij = perturbed_prepared_input.d_ij - perturbed_r_ij = perturbed_prepared_input.r_ij - perturbed_dir_ij = perturbed_r_ij / perturbed_d_ij - perturbed_f_ij = ( - painn.core_module.representation_module.radial_symmetry_function_module( - perturbed_d_ij - ) - ) - - # check that the invariant properties are preserved - # d_ij is the distance between atom i and j - # f_ij is the radial basis function of d_ij - assert torch.allclose(reference_d_ij, perturbed_d_ij) - assert torch.allclose(reference_f_ij, perturbed_f_ij) - - # what shoudl not be invariant is the direction - assert not torch.allclose(reference_dir_ij, perturbed_dir_ij) - - # Check for equivariance - # rotate the reference dir_ij - rotated_reference_dir_ij = torch.matmul(reference_dir_ij, rotation_matrix) - # Compare the rotated original dir_ij with the dir_ij from rotated positions - assert torch.allclose(rotated_reference_dir_ij, perturbed_dir_ij) - - # Test that the interaction block is equivariant - # First we test the transformed inputs - reference_tranformed_inputs = painn.core_module.representation_module( - reference_prepared_input - ) - perturbed_tranformed_inputs = painn.core_module.representation_module( - perturbed_prepared_input - ) - - assert torch.allclose( - reference_tranformed_inputs["q"], perturbed_tranformed_inputs["q"] - ) - assert torch.allclose( - reference_tranformed_inputs["mu"], perturbed_tranformed_inputs["mu"] - ) - - painn_interaction = painn.core_module.interaction_modules[0] - - reference_r = painn_interaction( - reference_tranformed_inputs["q"], - reference_tranformed_inputs["mu"], - reference_tranformed_inputs["filters"][0], - reference_dir_ij, - reference_prepared_input.pair_indices, - ) - - perturbed_r = painn_interaction( - perturbed_tranformed_inputs["q"], - perturbed_tranformed_inputs["mu"], - reference_tranformed_inputs["filters"][0], - perturbed_dir_ij, - perturbed_prepared_input.pair_indices, - ) - - perturbed_q, perturbed_mu = perturbed_r - reference_q, reference_mu = reference_r - - # mu is different, q is invariant - assert torch.allclose(reference_q, perturbed_q) - assert not torch.allclose(reference_mu, perturbed_mu) - - mixed_reference_q, mixed_reference_mu = painn.core_module.mixing_modules[0]( - reference_q, reference_mu - ) - mixed_perturbed_q, mixed_perturbed_mu = painn.core_module.mixing_modules[0]( - perturbed_q, perturbed_mu - ) - - # q is a scalar property and invariant - assert torch.allclose(mixed_reference_q, mixed_perturbed_q, atol=1e-2) - # mu is a vector property and should not be invariant - assert not torch.allclose(mixed_reference_mu, mixed_perturbed_mu) - - import torch from modelforge.tests.test_schnet import setup_single_methane_input -def test_compare_representation(): +def test_compare_implementation_against_reference_implementation(): # ---------------------------------------- # # setup the PaiNN model # ---------------------------------------- # - from openff.units import unit from .precalculated_values import load_precalculated_painn_results - from modelforge.tests.test_models import load_configs - # read default parameters - config = load_configs("painn", "qm9") - - torch.manual_seed(1234) - - # override defaults to match reference implementation in spk - config["potential"]["core_parameter"]["max_Z"] = 100 - config["potential"]["core_parameter"]["number_of_atom_features"] = 8 - config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5 - - # initialize model - model = PaiNN( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ).to(torch.float64) + potential = setup_painn_model(potential_seed=1234).double() # ------------------------------------ # # set up the input for the Painn model input = setup_single_methane_input() - spk_input = input["spk_methane_input"] - mf_nnp_input = input["modelforge_methane_input"] - - model.input_preparation._input_checks(mf_nnp_input) - pairlist_output = model.input_preparation.prepare_inputs(mf_nnp_input) - prepared_input = model.core_module._model_specific_input_preparation( - mf_nnp_input, pairlist_output - ) + nnp_input = input["modelforge_methane_input"] # ---------------------------------------- # # test forward pass # ---------------------------------------- # - # reset filter parameters torch.manual_seed(1234) - model.core_module.representation_module.filter_net.reset_parameters() + potential.core_network.representation_module.filter_net.reset_parameters() + + calculated_results = potential.compute_core_network_output(nnp_input) - calculated_results = model.core_module.forward(prepared_input, pairlist_output) reference_results = load_precalculated_painn_results() # check that the scalar and vector representations are the same # start with scalar representation assert ( reference_results["scalar_representation"].shape - == calculated_results["q"].shape + == calculated_results["per_atom_scalar_representation"].shape ) scalar_spk = reference_results["scalar_representation"].double() - scalar_mf = calculated_results["q"].double() + scalar_mf = calculated_results["per_atom_scalar_representation"].double() assert torch.allclose(scalar_spk, scalar_mf, atol=1e-4) # check vector representation assert ( reference_results["vector_representation"].shape - == calculated_results["mu"].shape + == calculated_results["per_atom_vector_representation"].shape ) assert torch.allclose( reference_results["vector_representation"].double(), - calculated_results["mu"].double(), + calculated_results["per_atom_vector_representation"].double(), atol=1e-4, ) diff --git a/modelforge/tests/test_pairlist.py b/modelforge/tests/test_pairlist.py new file mode 100644 index 00000000..e37910b2 --- /dev/null +++ b/modelforge/tests/test_pairlist.py @@ -0,0 +1,747 @@ +def test_pairlist_logic(): + import torch + + # dummy data for illustration + positions = torch.tensor( + [ + [0.4933, 0.4460, 0.5762], + [0.2340, 0.2053, 0.5025], + [0.6566, 0.1263, 0.8792], + [0.1656, 0.0338, 0.6708], + [0.5696, 0.4790, 0.9622], + [0.3499, 0.4241, 0.8818], + [0.8400, 0.9389, 0.1888], + [0.4983, 0.0793, 0.8639], + [0.6605, 0.7567, 0.1938], + [0.7725, 0.9758, 0.7063], + ] + ) + molecule_indices = torch.tensor( + [0, 0, 0, 1, 1, 2, 2, 2, 3, 3] + ) # molecule index for each atom + + # generate index grid + n = len(molecule_indices) + i_indices, j_indices = torch.triu_indices(n, n, 1) + + # filter pairs to only keep those belonging to the same molecule + same_molecule_mask = molecule_indices[i_indices] == molecule_indices[j_indices] + + # Apply mask to get final pair indices + i_final_pairs = i_indices[same_molecule_mask] + j_final_pairs = j_indices[same_molecule_mask] + + # Concatenate to form final (2, n_pairs) tensor + final_pair_indices = torch.stack((i_final_pairs, j_final_pairs)) + + assert torch.allclose( + final_pair_indices, + torch.tensor([[0, 0, 1, 3, 5, 5, 6, 8], [1, 2, 2, 4, 6, 7, 7, 9]]), + ) + + # Create pair_coordinates tensor + pair_coordinates = positions[final_pair_indices.T] + pair_coordinates = pair_coordinates.view(-1, 2, 3) + + # Calculate distances + distances = (pair_coordinates[:, 0, :] - pair_coordinates[:, 1, :]).norm( + p=2, dim=-1 + ) + # Calculate distances + distances = (pair_coordinates[:, 0, :] - pair_coordinates[:, 1, :]).norm( + p=2, dim=-1 + ) + + # Define a cutoff + cutoff = 1.0 + + # Find pairs within the cutoff + in_cutoff = (distances <= cutoff).nonzero().squeeze() + + # Get the atom indices within the cutoff + atom_pairs_withing_cutoff = final_pair_indices[:, in_cutoff] + assert torch.allclose( + atom_pairs_withing_cutoff, + torch.tensor([[0, 0, 1, 3, 5, 5, 8], [1, 2, 2, 4, 6, 7, 9]]), + ) + + +def test_pairlist(): + import torch + from collections import namedtuple + from modelforge.potential.neighbors import NeighborListForTraining, Pairlist + + TestInput = namedtuple( + "TestInput", ["positions", "atomic_subsystem_indices", "pair_list"] + ) + + atomic_subsystem_indices = torch.tensor([0, 0, 0, 1, 1, 1]) + positions = torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0], + [4.0, 4.0, 4.0], + [5.0, 5.0, 5.0], + ] + ) + from openff.units import unit + + cutoff = unit.Quantity(5.0, unit.nanometer).to(unit.nanometer).m + nlist = NeighborListForTraining(cutoff, only_unique_pairs=True) + + r = nlist(TestInput(positions, atomic_subsystem_indices, None)) + pair_indices = r.pair_indices + + # pairlist describes the pairs of interacting atoms within a batch + # that means for the pairlist provided below: + # pair1: pairlist[0][0] and pairlist[1][0], i.e. (0,1) + # pair2: pairlist[0][1] and pairlist[1][1], i.e. (0,2) + # pair3: pairlist[0][2] and pairlist[1][2], i.e. (1,2) + + assert torch.allclose( + pair_indices, torch.tensor([[0, 0, 1, 3, 3, 4], [1, 2, 2, 4, 5, 5]]) + ) + # NOTE: pairs are defined on axis=1 and not axis=0 + assert torch.allclose( + r.r_ij, + torch.tensor( + [ + [1.0, 1.0, 1.0], # pair1, [1.0, 1.0, 1.0] - [0.0, 0.0, 0.0] + [2.0, 2.0, 2.0], # pair2, [2.0, 2.0, 2.0] - [0.0, 0.0, 0.0] + [1.0, 1.0, 1.0], # pair3, [3.0, 3.0, 3.0] - [0.0, 0.0, 0.0] + [1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [1.0, 1.0, 1.0], + ] + ), + ) + + # test with cutoff + cutoff = unit.Quantity(2.0, unit.nanometer).to(unit.nanometer).m + nlist = NeighborListForTraining(cutoff, only_unique_pairs=True) + r = nlist(TestInput(positions, atomic_subsystem_indices, None)) + pair_indices = r.pair_indices + + assert torch.equal(pair_indices, torch.tensor([[0, 1, 3, 4], [1, 2, 4, 5]])) + # pairs that are excluded through cutoff: (0,2) and (3,5) + assert torch.equal( + r.r_ij, + torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ), + ) + + assert torch.allclose( + r.d_ij, torch.tensor([1.7321, 1.7321, 1.7321, 1.7321]), atol=1e-3 + ) + + # -------------------------------- # + # test with complete pairlist + cutoff = unit.Quantity(2.0, unit.nanometer).to(unit.nanometer).m + neigborlist = NeighborListForTraining(cutoff, only_unique_pairs=False) + r = neigborlist(TestInput(positions, atomic_subsystem_indices, None)) + pair_indices = r.pair_indices + + assert torch.equal( + pair_indices, torch.tensor([[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]]) + ) + + # -------------------------------- # + # make sure that Pairlist and Neighborlist behave the same for large cutoffs + cutoff = unit.Quantity(10.0, unit.nanometer).to(unit.nanometer).m + only_unique_pairs = False + neighborlist = NeighborListForTraining(cutoff, only_unique_pairs=only_unique_pairs) + pairlist = Pairlist(only_unique_pairs=only_unique_pairs) + r = pairlist(positions, atomic_subsystem_indices) + pair_indices = r.pair_indices + r = neighborlist(TestInput(positions, atomic_subsystem_indices, None)) + neighbor_indices = r.pair_indices + + assert torch.equal(pair_indices, neighbor_indices) + + # -------------------------------- # + # make sure that they are the same also for non-redundant pairs + cutoff = unit.Quantity(10.0, unit.nanometer).to(unit.nanometer).m + only_unique_pairs = True + neighborlist = NeighborListForTraining(cutoff, only_unique_pairs=only_unique_pairs) + pairlist = Pairlist(only_unique_pairs=only_unique_pairs) + r = pairlist(positions, atomic_subsystem_indices) + pair_indices = r.pair_indices + r = neighborlist(TestInput(positions, atomic_subsystem_indices, None)) + neighbor_indices = r.pair_indices + + assert torch.equal(pair_indices, neighbor_indices) + + # -------------------------------- # + # this should fail + cutoff = unit.Quantity(2.0, unit.nanometer).to(unit.nanometer).m + only_unique_pairs = True + neighborlist = NeighborListForTraining(cutoff, only_unique_pairs=only_unique_pairs) + pairlist = Pairlist(only_unique_pairs=only_unique_pairs) + r = pairlist(positions, atomic_subsystem_indices) + pair_indices = r.pair_indices + r = neighborlist(TestInput(positions, atomic_subsystem_indices, None)) + neighbor_indices = r.pair_indices + + assert not pair_indices.shape == neighbor_indices.shape + + +def test_pairlist_precomputation(): + import numpy as np + import torch + + from modelforge.potential.neighbors import Pairlist + + atomic_subsystem_indices = torch.tensor([0, 0, 0]) + + pairlist = Pairlist() + + pairs, nr_pairs = pairlist.construct_initial_pairlist_using_numpy( + atomic_subsystem_indices.to("cpu") + ) + + assert pairs.shape == (2, 6) + assert nr_pairs[0] == 6 + + # 3 molecules, 3 atoms each + atomic_subsystem_indices = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]) + pairs, nr_pairs = pairlist.construct_initial_pairlist_using_numpy( + atomic_subsystem_indices.to("cpu") + ) + + assert pairs.shape == (2, 18) + assert np.all(nr_pairs == [6, 6, 6]) + + # 3 molecules, 3,4, and 5 atoms each + atomic_subsystem_indices = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]) + pairs, nr_pairs = pairlist.construct_initial_pairlist_using_numpy( + atomic_subsystem_indices.to("cpu") + ) + + assert pairs.shape == (2, 38) + assert np.all(nr_pairs == [6, 12, 20]) + + +def test_pairlist_on_dataset(): + # Set up a dataset + from modelforge.dataset.dataset import DataModule + from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy + + # prepare reference value + dataset = DataModule( + name="QM9", + batch_size=1, + version_select="nc_1000_v0", + splitting_strategy=FirstComeFirstServeSplittingStrategy(), + remove_self_energies=True, + regression_ase=False, + ) + dataset.prepare_data() + dataset.setup() + # -------------------------------# + # -------------------------------# + # get methane input + batch = next(iter(dataset.train_dataloader(shuffle=False))).nnp_input + import torch + + # make sure that the pairlist of methane is correct (single molecule) + assert torch.equal( + batch.pair_list, + torch.tensor( + [ + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + [1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3], + ] + ), + ) + + # test that the pairlist of 2 molecules is correct (which can then be expected also to be true for N molecules) + dataset = DataModule( + name="QM9", + batch_size=2, + version_select="nc_1000_v0", + splitting_strategy=FirstComeFirstServeSplittingStrategy(), + remove_self_energies=True, + regression_ase=False, + ) + dataset.prepare_data() + dataset.setup() + # -------------------------------# + # -------------------------------# + # get methane input + batch = next(iter(dataset.train_dataloader(shuffle=False))).nnp_input + + assert torch.equal( + batch.pair_list, + torch.tensor( + [ + [ + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 6, + 6, + 6, + 7, + 7, + 7, + 8, + 8, + 8, + ], + [ + 1, + 2, + 3, + 4, + 0, + 2, + 3, + 4, + 0, + 1, + 3, + 4, + 0, + 1, + 2, + 4, + 0, + 1, + 2, + 3, + 6, + 7, + 8, + 5, + 7, + 8, + 5, + 6, + 8, + 5, + 6, + 7, + ], + ] + ), + ) + + # check that the pairlist maximum value for i is the number of atoms in the batch + assert ( + int(batch.pair_list[0][-1].item()) + 1 == 8 + 1 == len(batch.atomic_numbers) + ) # +1 because of 0-based indexing + + +def test_displacement_function(): + """Test that OrthogonalDisplacementFunction behaves as expected, including toggling periodicity""" + import torch + + from modelforge.potential.neighbors import OrthogonalDisplacementFunction + + displacement_function = OrthogonalDisplacementFunction() + + box_vectors = torch.tensor( + [[10, 0, 0], [0, 10, 0], [0, 0, 10]], dtype=torch.float32 + ) + + coords1 = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0, 0, 0], + [1.0, 1.0, 1.0], + [3.0, 3.0, 3.0], + [8.5, 8.5, 8.5], + ], + dtype=torch.float32, + ) + + coords2 = torch.tensor( + [ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [1.0, 1.0, 1.0], + [0, 0, 0], + [8.5, 8.5, 8.5], + [3.0, 3.0, 3.0], + ], + dtype=torch.float32, + ) + r_ij, d_ij = displacement_function(coords1, coords1, box_vectors, is_periodic=True) + + assert torch.allclose(r_ij, torch.zeros_like(r_ij)) + assert torch.allclose(d_ij, torch.zeros_like(d_ij)) + + r_ij, d_ij = displacement_function(coords1, coords2, box_vectors, is_periodic=True) + + assert torch.allclose( + r_ij, + torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [-1.0, -1.0, -1.0], + [1.0, 1.0, 1.0], + [4.5, 4.5, 4.5], + [-4.5, -4.5, -4.5], + ], + dtype=r_ij.dtype, + ), + ) + + assert torch.allclose( + d_ij, + torch.tensor( + [[1.0], [1.0], [1.0], [1.7321], [1.7321], [7.7942], [7.7942]], + dtype=d_ij.dtype, + ), + atol=1e-4, + ) + # make sure the function works if the box is not periodic + displacement_function = OrthogonalDisplacementFunction() + r_ij, d_ij = displacement_function(coords1, coords1, box_vectors, is_periodic=False) + + assert torch.allclose(r_ij, torch.zeros_like(r_ij)) + assert torch.allclose(d_ij, torch.zeros_like(d_ij)) + + r_ij, d_ij = displacement_function(coords1, coords2, box_vectors, is_periodic=False) + + # since the + assert torch.allclose(r_ij, coords1 - coords2) + assert torch.allclose(d_ij, torch.norm(r_ij, dim=1, keepdim=True, p=2)) + + +def test_inference_neighborlist_building(): + """Test that NeighborlistBruteNsq and NeighborlistVerletNsq behave identically when building the neighborlist""" + from modelforge.potential.neighbors import ( + NeighborlistBruteNsq, + NeighborlistVerletNsq, + OrthogonalDisplacementFunction, + ) + import torch + + from modelforge.dataset.dataset import NNPInput + + displacement_function = OrthogonalDisplacementFunction() + + positions = torch.tensor( + [[0.0, 0, 0], [1, 0, 0], [3.0, 0, 0], [8, 0, 0]], dtype=torch.float32 + ) + + data = NNPInput( + atomic_numbers=torch.tensor([1, 1, 1, 1], dtype=torch.int64), + positions=positions, + atomic_subsystem_indices=torch.tensor([0, 0, 0, 0], dtype=torch.int64), + per_system_total_charge=torch.tensor([0.0], dtype=torch.float32), + box_vectors=torch.tensor( + [[10, 0, 0], [0, 10, 0], [0, 0, 10]], dtype=torch.float32 + ), + is_periodic=True, + ) + # test to + nlist = NeighborlistBruteNsq( + cutoff=5.0, displacement_function=displacement_function, only_unique_pairs=False + ) + pairs, d_ij, r_ij = nlist(data) + + assert pairs.shape[1] == 12 + + nlist_verlet = NeighborlistVerletNsq( + cutoff=5.0, + displacement_function=displacement_function, + skin=0.5, + only_unique_pairs=False, + ) + + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + assert pairs_v.shape[1] == pairs.shape[1] + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + nlist = NeighborlistBruteNsq( + cutoff=5.0, displacement_function=displacement_function, only_unique_pairs=True + ) + + pairs, d_ij, r_ij = nlist(data) + + assert pairs.shape[1] == 6 + + nlist_verlet = NeighborlistVerletNsq( + cutoff=5.0, + displacement_function=displacement_function, + skin=0.5, + only_unique_pairs=True, + ) + + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + assert pairs_v.shape[1] == pairs.shape[1] + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + nlist = NeighborlistBruteNsq( + cutoff=3.5, displacement_function=displacement_function, only_unique_pairs=False + ) + + pairs, d_ij, r_ij = nlist(data) + + assert pairs.shape[1] == 10 + + assert torch.all(d_ij <= 3.5) + + assert torch.all( + pairs + == torch.tensor( + [[0, 0, 0, 1, 1, 1, 2, 3, 2, 3], [1, 2, 3, 2, 3, 0, 0, 0, 1, 1]] + ) + ) + + assert torch.allclose( + r_ij, + torch.tensor( + [ + [-1.0, 0.0, 0.0], + [-3.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [-2.0, 0.0, 0.0], + [3.0, 0.0, 0.0], + [1.0, -0.0, -0.0], + [3.0, -0.0, -0.0], + [-2.0, -0.0, -0.0], + [2.0, -0.0, -0.0], + [-3.0, -0.0, -0.0], + ] + ), + ) + + assert torch.allclose( + d_ij, + torch.tensor( + [[1.0], [3.0], [2.0], [2.0], [3.0], [1.0], [3.0], [2.0], [2.0], [3.0]] + ), + ) + + nlist_verlet = NeighborlistVerletNsq( + cutoff=3.5, + displacement_function=displacement_function, + skin=0.5, + only_unique_pairs=False, + ) + + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + assert pairs_v.shape[1] == pairs.shape[1] + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + displacement_function = OrthogonalDisplacementFunction() + + nlist = NeighborlistBruteNsq( + cutoff=5.0, displacement_function=displacement_function, only_unique_pairs=False + ) + + data.is_periodic = False + + pairs, d_ij, r_ij = nlist(data) + + assert pairs.shape[1] == 8 + assert torch.all(d_ij <= 5.0) + + nlist_verlet = NeighborlistVerletNsq( + cutoff=5.0, + displacement_function=displacement_function, + skin=0.5, + only_unique_pairs=False, + ) + + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + assert pairs_v.shape[1] == pairs.shape[1] + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + # test updates to verlet list + + positions = torch.tensor( + [[0.0, 0, 0], [1, 0, 0], [3.0, 0, 0], [8, 0, 0]], dtype=torch.float32 + ) + + +def test_verlet_inference(): + """Test to ensure that the verlet neighborlist properly updates by comparing to brute force neighborlist""" + from modelforge.potential.neighbors import ( + NeighborlistBruteNsq, + NeighborlistVerletNsq, + OrthogonalDisplacementFunction, + ) + import torch + + from modelforge.dataset.dataset import NNPInput + + def return_data(positions, box_length=10, is_periodic=True): + return NNPInput( + atomic_numbers=torch.ones(positions.shape[0], dtype=torch.int64), + positions=positions, + atomic_subsystem_indices=torch.zeros(positions.shape[0], dtype=torch.int64), + per_system_total_charge=torch.tensor([0.0], dtype=torch.float32), + box_vectors=torch.tensor( + [[box_length, 0, 0], [0, box_length, 0], [0, 0, box_length]], + dtype=torch.float32, + ), + is_periodic=is_periodic, + ) + + positions = torch.tensor( + [[2.0, 0, 0], [1.0, 0, 0], [0.0, 0.0, 0]], dtype=torch.float32 + ) + data = return_data(positions) + + displacement_function = OrthogonalDisplacementFunction() + nlist_verlet = NeighborlistVerletNsq( + cutoff=1.5, + displacement_function=displacement_function, + skin=0.5, + only_unique_pairs=True, + ) + + nlist_brute = NeighborlistBruteNsq( + cutoff=1.5, + displacement_function=displacement_function, + only_unique_pairs=True, + ) + + print("first check") + pairs, d_ij, r_ij = nlist_brute(data) + + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + + assert nlist_verlet.builds == 1 + assert pairs.shape[1] == 2 + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + # move one particle father away, but still interacting + positions = torch.tensor( + [[2.2, 0, 0], [1.0, 0, 0], [0.0, 0.0, 0]], dtype=torch.float32 + ) + data = return_data(positions) + + pairs, d_ij, r_ij = nlist_brute(data) + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + + # since we didn't move far enough to trigger a rebuild of the verlet list, the results should be the same + assert nlist_verlet.builds == 1 + assert pairs.shape[1] == 2 + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + # move one particle father away, but still interacting, but enough to trigger a rebuild, + # since rebuild 0.5*skin = 0.25 + positions = torch.tensor( + [[2.3, 0, 0], [1.0, 0, 0], [0.0, 0.0, 0]], dtype=torch.float32 + ) + data = return_data(positions) + + pairs, d_ij, r_ij = nlist_brute(data) + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + + assert nlist_verlet.builds == 2 + assert pairs.shape[1] == 2 + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + # move one particle farther away so it no longer interacts; but less than 0.5*skin = 0.25 since last rebuild, + # so no rebuilding will occur + positions = torch.tensor( + [[2.51, 0, 0], [1.0, 0, 0], [0.0, 0.0, 0]], dtype=torch.float32 + ) + data = return_data(positions) + + pairs, d_ij, r_ij = nlist_brute(data) + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + + assert nlist_verlet.builds == 2 + assert pairs.shape[1] == 1 + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + # move the particle back such that it is interacting, but less than half the skin, so no rebuild + positions = torch.tensor( + [[2.45, 0, 0], [1.0, 0, 0], [0.0, 0.0, 0]], dtype=torch.float32 + ) + data = return_data(positions) + + pairs, d_ij, r_ij = nlist_brute(data) + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + + assert nlist_verlet.builds == 2 + assert pairs.shape[1] == 2 + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + # force a rebuild by changing box_vectors + positions = torch.tensor( + [[2.45, 0, 0], [1.0, 0, 0], [0.0, 0.0, 0]], dtype=torch.float32 + ) + data = return_data(positions, 9) + + pairs, d_ij, r_ij = nlist_brute(data) + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + + assert nlist_verlet.builds == 3 + assert pairs.shape[1] == 2 + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) + + # force rebuild by changing number of particles; but let's add a particle that doesn't interact + positions = torch.tensor( + [[2.45, 0, 0], [1.0, 0, 0], [0.0, 0.0, 0], [4, 0, 0]], dtype=torch.float32 + ) + data = return_data(positions, 9) + + pairs, d_ij, r_ij = nlist_brute(data) + pairs_v, d_ij_v, r_ij_v = nlist_verlet(data) + + assert nlist_verlet.builds == 4 + assert pairs.shape[1] == 2 + assert torch.all(pairs_v == pairs) + assert torch.allclose(d_ij_v, d_ij) + assert torch.allclose(r_ij_v, r_ij) diff --git a/modelforge/tests/test_parameter_models.py b/modelforge/tests/test_parameter_models.py new file mode 100644 index 00000000..f0d681ac --- /dev/null +++ b/modelforge/tests/test_parameter_models.py @@ -0,0 +1,150 @@ +import pytest + +from pydantic import ValidationError +from modelforge.potential import _Implemented_NNPs + + +def test_dataset_parameter_model(): + + from modelforge.dataset.dataset import DatasetParameters + + # test to ensure we can properly initialize + dataset_parameter_dict = { + "dataset_name": "QM9", + "version_select": "latest", + "num_workers": 4, + "pin_memory": True, + } + + dataset_parameters = DatasetParameters(**dataset_parameter_dict) + + # test the validator on num_workers that asserts it must be greater than 0 + dataset_parameter_dict = { + "dataset_name": "QM9", + "version_select": "latest", + "num_workers": -1, + "pin_memory": True, + } + + with pytest.raises(ValidationError): + dataset_parameters = DatasetParameters(**dataset_parameter_dict) + + # test to ensure error is raised if we do not provide all parameters + dataset_parameter_dict = { + "dataset_name": "QM9", + "version_select": "latest", + "num_workers": 4, + } + + with pytest.raises(ValidationError): + dataset_parameters = DatasetParameters(**dataset_parameter_dict) + + # test to ensure error is raised if we set a wrong type + dataset_parameter_dict = { + "dataset_name": "QM9", + "version_select": "latest", + "num_workers": 4, + "pin_memory": "totally_true", + } + + with pytest.raises(ValidationError): + dataset_parameters = DatasetParameters(**dataset_parameter_dict) + + # we should raise an error if we assign a wrong type to dataset_name + with pytest.raises(ValidationError): + dataset_parameters.dataset_name = 4 + + # check the validator that asserts number of workers must be greater than 0 during assignment + with pytest.raises(ValidationError): + dataset_parameters.num_workers = 0 + + +def test_convert_str_to_unit(): + # Test the validator that will automatically convert a string formated like "1.0 angstrom" to a unit.Quantity + + from modelforge.utils.units import _convert_str_to_unit + from openff.units import unit + + assert _convert_str_to_unit("1.0 angstrom") == unit.Quantity("1.0 angstrom") + assert _convert_str_to_unit(unit.Quantity("1.0 angstrom")) == unit.Quantity( + "1.0 angstrom" + ) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_potential_parameter_model(potential_name): + from modelforge.tests.data import potential_defaults + + from importlib import resources + import toml + + potential_path = ( + resources.files(potential_defaults) / f"{potential_name.lower()}.toml" + ) + potential_config_dict = toml.load(potential_path) + + from modelforge.potential import _Implemented_NNP_Parameters + + PotentialParameters = ( + _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name) + ) + + # test to ensure we can properly initialize + potential_parameters = PotentialParameters(**potential_config_dict["potential"]) + + +def test_runtime_parameter_model(): + from modelforge.train.parameters import RuntimeParameters + from modelforge.tests.data import runtime_defaults + + from importlib import resources + import toml + + runtime_path = resources.files(runtime_defaults) / "runtime.toml" + runtime_config_dict = toml.load(runtime_path) + + # test to ensure we can properly initialize + runtime_parameters = RuntimeParameters(**runtime_config_dict["runtime"]) + + with pytest.raises(ValidationError): + runtime_parameters.number_of_nodes = -1 + + with pytest.raises(ValidationError): + runtime_parameters.devices = -1 + + with pytest.raises(ValidationError): + runtime_parameters.devices = [-1, 0] + + with pytest.raises(ValidationError): + runtime_parameters.accelerator = "not_a_valid_accelerator" + + +def test_training_parameter_model(): + from modelforge.train.parameters import TrainingParameters + from modelforge.tests.data import training_defaults + + from importlib import resources + import toml + + training_path = resources.files(training_defaults) / "default.toml" + training_config_dict = toml.load(training_path) + + # test to ensure we can properly initialize + training_parameters = TrainingParameters(**training_config_dict["training"]) + + # this will throw an error because the split should sum to 1 + with pytest.raises(ValidationError): + training_parameters.splitting_strategy.dataset_split = [0.1, 0.1, 0.1] + + # this will throw an error because the split should be of length 3 + with pytest.raises(ValidationError): + training_parameters.splitting_strategy.dataset_split = [0.7, 0.1, 0.1, 0.1] + + # this will throw an error because the datafile has 1 entries for the loss_components dictionary + with pytest.raises(ValidationError): + training_parameters.loss_parameter.loss_components = [ + "per_system_energy", + "per_atom_force", + ] diff --git a/modelforge/tests/test_physnet.py b/modelforge/tests/test_physnet.py index 5e860c06..4c811d2f 100644 --- a/modelforge/tests/test_physnet.py +++ b/modelforge/tests/test_physnet.py @@ -1,44 +1,36 @@ -def test_init(): - - from modelforge.potential.physnet import PhysNet +from typing import Optional +import pytest - from modelforge.tests.test_models import load_configs +from modelforge.tests.helper_functions import setup_potential_for_test - # read default parameters - config = load_configs(f"physnet", "qm9") - model = PhysNet( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ) +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_physnet_temp") + return fn -def test_forward(single_batch_with_batchsize_64): - import torch - from modelforge.potential.physnet import PhysNet +def test_init(): - # read default parameters - from modelforge.tests.test_models import load_configs + model = setup_potential_for_test("physnet", "training") + assert model is not None, "PhysNet model should be initialized." - # read default parameters - config = load_configs(f"physnet", "qm9") - # Extract parameters - config["potential"]["core_parameter"]["number_of_modules"] = 1 - config["potential"]["core_parameter"]["number_of_interaction_residual"] = 1 +def test_forward(single_batch_with_batchsize, prep_temp_dir): + import torch - model = PhysNet( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ) - model = model.to(torch.float32) + model = setup_potential_for_test("physnet", "training") print(model) - yhat = model(single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float32)) + batch = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + + yhat = model(batch.nnp_input.to_dtype(dtype=torch.float32)) def test_compare_representation(): - # This test compares the RBF calculation of the original - # PhysNet implemntation against the SAKE/PhysNet implementation in modelforge + # This test compares the RBF calculation of the original PhysNet + # implemntation against the SAKE/PhysNet implementation in modelforge # # NOTE: input in PhysNet is expected in angstrom, in contrast to modelforge which expects input in nanomter import numpy as np @@ -47,16 +39,16 @@ def test_compare_representation(): # set up test parameters number_of_radial_basis_functions = K = 20 - _max_distance_in_nanometer = 0.5 + _max_distance = unit.Quantity(5, unit.angstrom) ############################# # RBF comparision ############################# # Initialize the rbf class - from modelforge.potential.utils import PhysNetRadialBasisFunction + from modelforge.potential import PhysNetRadialBasisFunction rbf = PhysNetRadialBasisFunction( number_of_radial_basis_functions, - max_distance=_max_distance_in_nanometer * unit.nanometer, + max_distance=_max_distance.to(unit.nanometer).m, ) # compare the rbf output @@ -65,7 +57,7 @@ def test_compare_representation(): from .precalculated_values import provide_reference_for_test_physnet_test_rbf reference_rbf = provide_reference_for_test_physnet_test_rbf() - D = np.array([[1.0394776], [3.375541]], dtype=np.float32) + D = np.array([[1.0394776], [3.375541]], dtype=np.float32) / 10 - calculated_rbf = rbf(torch.tensor(D / 10)) + calculated_rbf = rbf(torch.tensor(D)) assert np.allclose(np.flip(reference_rbf.squeeze(), axis=1), calculated_rbf.numpy()) diff --git a/modelforge/tests/test_potentials.py b/modelforge/tests/test_potentials.py new file mode 100644 index 00000000..cb69e05a --- /dev/null +++ b/modelforge/tests/test_potentials.py @@ -0,0 +1,1113 @@ +from typing import Literal + +import pytest +import torch +from openff.units import unit + +from modelforge.dataset import _ImplementedDatasets +from modelforge.potential import NeuralNetworkPotentialFactory, _Implemented_NNPs +from modelforge.tests.helper_functions import ( + setup_potential_for_test, + _add_electrostatic_to_predicted_properties, + _add_per_atom_charge_to_predicted_properties, + _add_per_atom_charge_to_properties_to_process, +) +from modelforge.utils.io import import_ +from modelforge.utils.misc import load_configs_into_pydantic_models + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_models_temp") + return fn + + +def initialize_model( + simulation_environment: Literal["PyTorch", "JAX"], config, jit: bool +): + """Initialize the model based on the simulation environment and configuration.""" + return NeuralNetworkPotentialFactory.generate_potential( + simulation_environment=simulation_environment, + potential_parameter=config["potential"], + jit=jit, + use_training_mode_neighborlist=True, + ) + + +def prepare_input_for_model(nnp_input, model): + """Prepare the input for the model based on the simulation environment.""" + if "JAX" in str(type(model)): + from modelforge.jax import convert_NNPInput_to_jax + + return convert_NNPInput_to_jax(nnp_input) + return nnp_input + + +def validate_output_shapes(output, nr_of_mols: int, energy_expression: str): + """Validate the output shapes to ensure they are correct.""" + assert len(output["per_system_energy"]) == nr_of_mols + assert "per_atom_energy" in output + if energy_expression == "short_range_and_long_range_electrostatic": + assert "per_atom_charge" in output + assert "per_atom_charge_uncorrected" in output + assert "electrostatic_energy" in output + + +def validate_charge_conservation( + per_system_total_charge: torch.Tensor, + per_system_total_charge_uncorrected: torch.Tensor, + per_system_total_charge_from_dataset: torch.Tensor, + model_name: str, +): + """Ensure charge conservation by validating the corrected charges.""" + + if "PhysNet".lower() in model_name.lower(): + print( + "Physnet starts with all zero partial charges" + ) # NOTE: I am not sure if this is correct + else: + assert not torch.allclose( + per_system_total_charge, per_system_total_charge_uncorrected + ) + assert torch.allclose( + per_system_total_charge_from_dataset.to(torch.float32), + per_system_total_charge, + atol=1e-5, + ) + + +from typing import Dict + + +def validate_per_atom_and_per_system_properties(output: Dict[str, torch.Tensor]): + """Ensure that the total energy is the sum of atomic energies.""" + assert torch.allclose( + output["per_system_energy"][0], + output["per_atom_energy"][0:5].sum(dim=0), + atol=1e-5, + ) + assert torch.allclose( + output["per_system_energy"][1], + output["per_atom_energy"][5:9].sum(dim=0), + atol=1e-5, + ) + + +def validate_chemical_equivalence(output): + """Ensure that chemically equivalent hydrogens have equal energies.""" + assert torch.allclose( + output["per_atom_energy"][1:4], output["per_atom_energy"][1], atol=1e-4 + ) + assert torch.allclose( + output["per_atom_energy"][6:8], output["per_atom_energy"][6], atol=1e-4 + ) + + +def retrieve_molecular_charges(output, atomic_subsystem_indices): + """Retrieve per-molecule charge from per-atom charges.""" + per_system_total_charge = torch.zeros_like(output["per_system_energy"]).index_add_( + 0, atomic_subsystem_indices, output["per_atom_charge"] + ) + per_system_total_charge_uncorrected = torch.zeros_like( + output["per_system_energy"] + ).index_add_(0, atomic_subsystem_indices, output["per_atom_charge_uncorrected"]) + return per_system_total_charge, per_system_total_charge_uncorrected + + +def convert_to_pytorch_if_needed(output, nnp_input, model): + """Convert output to PyTorch tensors if the model is in JAX.""" + if "JAX" in str(type(model)): + convert_to_pyt = import_("pytorch2jax").pytorch2jax.convert_to_pyt + output["per_system_energy"] = convert_to_pyt(output["per_system_energy"]) + output["per_atom_energy"] = convert_to_pyt(output["per_atom_energy"]) + + if "per_atom_charge" in output: + output["per_atom_charge"] = convert_to_pyt(output["per_atom_charge"]) + if "per_system_total_charge" in output: + output["per_system_total_charge"] = convert_to_pyt( + output["per_system_total_charge"] + ).to(torch.float32) + + atomic_subsystem_indices = convert_to_pyt(nnp_input.atomic_subsystem_indices) + else: + atomic_subsystem_indices = nnp_input.atomic_subsystem_indices + return output, atomic_subsystem_indices + + +def test_electrostatics(): + from modelforge.potential.processing import CoulombPotential + + e_elec = CoulombPotential(1.0) + per_atom_charge = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]) + # FIXME: this thest has to be implemented + + +""" +pairlist = PairListOutputs( +pair_indices=torch.tensor([[0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4], + [1,2,3,4,0,2,3,4,0,1,3,4,0,1,2,4,0,1,2,3]]), +d_ij = torch.tensor([ + ) + + pairwise_properties = {} + pairwise_properties["maximum_interaction_radius"] = + """ + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_JAX_wrapping(potential_name, single_batch_with_batchsize, prep_temp_dir): + + batch = single_batch_with_batchsize( + batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + + # read default parameters + potential = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name=potential_name, + simulation_environment="JAX", + local_cache_dir=str(prep_temp_dir), + ) + from modelforge.jax import convert_NNPInput_to_jax + + nnp_input = convert_NNPInput_to_jax(batch.nnp_input) + out = potential(nnp_input)["per_system_energy"] + import jax + + assert "JAX" in str(type(potential)) + + grad_fn = jax.grad(lambda pos: out.sum()) # Create a gradient function + forces = -grad_fn( + nnp_input.positions + ) # Evaluate gradient function and apply negative sign + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_model_factory(potential_name, prep_temp_dir): + # inference model + potential = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name=potential_name, + simulation_environment="PyTorch", + local_cache_dir=str(prep_temp_dir), + ) + assert ( + potential_name.upper() in str(type(potential.core_network)).upper() + or "JAX" in str(type(potential)).upper() + ) + potential = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name=potential_name, + simulation_environment="PyTorch", + jit=True, + use_default_dataset_statistic=False, + local_cache_dir=str(prep_temp_dir), + ) + + # trainers model + trainer = setup_potential_for_test( + use="training", + potential_seed=42, + potential_name=potential_name, + simulation_environment="PyTorch", + local_cache_dir=str(prep_temp_dir), + ) + assert ( + potential_name.upper() in str(type(trainer.core_network)).upper() + or "JAX" in str(type(trainer)).upper() + ) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_energy_scaling_and_offset( + potential_name, single_batch_with_batchsize, prep_temp_dir +): + from modelforge.potential.potential import NeuralNetworkPotentialFactory + + # read default parameters + config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") + + config["runtime"].local_cache_dir = str(prep_temp_dir) + + # inference model + trainer = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config["potential"], + training_parameter=config["training"], + dataset_parameter=config["dataset"], + runtime_parameter=config["runtime"], + ) + + batch = single_batch_with_batchsize( + batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + methane = batch.nnp_input + + # load dataset statistic + import toml + + dataset_statistic = toml.load(trainer.datamodule.dataset_statistic_filename) + # -------------------------------# + # initialize model without any postprocessing + # -------------------------------# + + potential = NeuralNetworkPotentialFactory.generate_potential( + potential_parameter=config["potential"], + potential_seed=42, + ) + output_no_postprocessing = potential(methane) + # -------------------------------# + # Scale output + potential = NeuralNetworkPotentialFactory.generate_potential( + potential_parameter=config["potential"], + dataset_statistic=trainer.dataset_statistic, + potential_seed=42, + ) + scaled_output = potential(methane) + + # make sure that the scaled output equals the unscaled output + + mean = unit.Quantity( + dataset_statistic["training_dataset_statistics"]["per_atom_energy_mean"] + ).m + stddev = unit.Quantity( + dataset_statistic["training_dataset_statistics"]["per_atom_energy_stddev"] + ).m + + # NOTE: only the per_system_energy is scaled + compare_to = output_no_postprocessing["per_atom_energy"] * stddev + mean + assert torch.allclose(scaled_output["per_system_energy"], compare_to.sum()) + + +""" +tensor([[-406.9472], + [-397.2831], + [-397.2831], + [-397.2831], + [-397.2831]]) +tensor([[-402.9324], + [-401.2677], + [-401.2677], + [-401.2683], + [-401.2684]]) +""" + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_state_dict_saving_and_loading(potential_name, prep_temp_dir): + import torch + + # give this a unique name so that we can run tests in parallel + file_path = f"{str(prep_temp_dir)}/{potential_name.lower()}_tsdsal_potential.pth" + from modelforge.potential import NeuralNetworkPotentialFactory + + # read default parameters + config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") + + config["runtime"].local_cache_dir = str(prep_temp_dir) + # ------------------------------------------------------------- # + # Use case 0: + # train a model, save the state_dict and load it again + trainer = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config["potential"], + training_parameter=config["training"], + runtime_parameter=config["runtime"], + dataset_parameter=config["dataset"], + ) + torch.save(trainer.lightning_module.state_dict(), file_path) + trainer.lightning_module.load_state_dict(torch.load(file_path)) + + # ------------------------------------------------------------- # + # Use case 1 + # generate a new trainer and load from a state_dict file + trainer2 = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config["potential"], + training_parameter=config["training"], + runtime_parameter=config["runtime"], + dataset_parameter=config["dataset"], + ) + trainer2.lightning_module.load_state_dict(torch.load(file_path)) + + # ------------------------------------------------------------- # + # Use case 2: + # load the model in inference mode + potential = NeuralNetworkPotentialFactory.generate_potential( + simulation_environment="PyTorch", + potential_parameter=config["potential"], + ) + potential.load_state_dict(torch.load(file_path)) + + +def test_loading_from_checkpoint_file(): + from importlib import resources + from modelforge.tests import data + + # checkpoint file is saved in tests/data + ckpt_file = str(resources.files(data) / "best_SchNet-PhAlkEthOH-epoch=00.ckpt") + print(ckpt_file) + + from modelforge.potential.potential import load_inference_model_from_checkpoint + + potential = load_inference_model_from_checkpoint(ckpt_file) + assert potential is not None + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_dataset_statistic(potential_name, prep_temp_dir): + # Test that the scaling parmaeters are propagated from the dataset to the + # runtime_defaults model and then via the state_dict to the inference model + + import numpy as np + import torch + from openff.units import unit + + from modelforge.dataset.dataset import DataModule + from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy + + # read default parameters + config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") + + # Extract parameters + potential_parameter = config["potential"] + training_parameter = config["training"] + dataset_parameter = config["dataset"] + runtime_parameter = config["runtime"] + + runtime_parameter.local_cache_dir = str(prep_temp_dir) + + # test the self energy calculation on the QM9 dataset + dataset = DataModule( + name="QM9", + batch_size=64, + version_select="nc_1000_v0", + splitting_strategy=FirstComeFirstServeSplittingStrategy(), + remove_self_energies=True, + regression_ase=False, + regenerate_dataset_statistic=True, + local_cache_dir=str(prep_temp_dir), + ) + dataset.prepare_data() + dataset.setup() + + # load dataset stastics from file + from modelforge.potential.utils import read_dataset_statistics + + dataset_statistic = read_dataset_statistics(dataset.dataset_statistic_filename) + # extract value to compare against + toml_E_i_mean = unit.Quantity( + dataset_statistic["training_dataset_statistics"]["per_atom_energy_mean"] + ).m + + trainer = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=potential_parameter, + training_parameter=training_parameter, + dataset_parameter=dataset_parameter, + runtime_parameter=runtime_parameter, + ) + # check that the per_atom_energy_mean is the same as in the dataset statistics + assert np.isclose( + toml_E_i_mean, + unit.Quantity( + trainer.dataset_statistic["training_dataset_statistics"][ + "per_atom_energy_mean" + ] + ).m, + ) + # give this a unique filename based on potential and the test we are in so we can run test in parallel + file_path = f"{str(prep_temp_dir)}/{potential_name.lower()}_tsd_potential.pth" + + torch.save(trainer.lightning_module.state_dict(), file_path) + + # NOTE: we are passing dataset statistics explicit to the constructor + # this is not saved with the state_dict + potential = NeuralNetworkPotentialFactory.generate_potential( + simulation_environment="PyTorch", + potential_parameter=config["potential"], + dataset_statistic=dataset_statistic, + ) + potential.load_state_dict(torch.load(file_path)) + + assert np.isclose( + toml_E_i_mean, + unit.Quantity( + potential.postprocessing.dataset_statistic["training_dataset_statistics"][ + "per_atom_energy_mean" + ] + ).m, + ) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_energy_between_simulation_environments( + potential_name, single_batch_with_batchsize, prep_temp_dir +): + # compare that the energy is the same for the JAX and PyTorch Model + import numpy as np + + batch = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + nnp_input = batch.nnp_input + # test the forward pass through each of the models + # cast input and model to torch.float64 + # read default parameters + potential = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name=potential_name, + simulation_environment="PyTorch", + local_cache_dir=str(prep_temp_dir), + ) + output_torch = potential(nnp_input)["per_system_energy"] + + potential = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name=potential_name, + simulation_environment="JAX", + local_cache_dir=str(prep_temp_dir), + ) + from modelforge.jax import convert_NNPInput_to_jax + + nnp_input = convert_NNPInput_to_jax(batch.nnp_input) + output_jax = potential(nnp_input)["per_system_energy"] + + # test tat we get an energie per molecule + assert np.isclose(output_torch.sum().detach().numpy(), output_jax.sum()) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +@pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) +def test_forward_pass_with_all_datasets( + potential_name, dataset_name, datamodule_factory, prep_temp_dir +): + """Test forward pass with all datasets.""" + import toml + import torch + + from modelforge.potential.potential import NeuralNetworkPotentialFactory + + # -------------------------------# + # setup dataset + # use a subset of the SPICE2 dataset for ANI2x + if dataset_name.lower().startswith("spice"): + print("using subset") + dataset = datamodule_factory( + dataset_name=dataset_name, + version_select="nc_1000_v0_HCNOFClS", + local_cache_dir=str(prep_temp_dir), + ) + else: + dataset = datamodule_factory( + dataset_name=dataset_name, local_cache_dir=str(prep_temp_dir) + ) + + dataset_statistic = toml.load(dataset.dataset_statistic_filename) + train_dataloader = dataset.train_dataloader() + batch = next(iter(train_dataloader)) + # -------------------------------# + # setup model + config = load_configs_into_pydantic_models( + f"{potential_name.lower()}", dataset_name.lower() + ) + potential = NeuralNetworkPotentialFactory.generate_potential( + potential_parameter=config["potential"], + dataset_statistic=dataset_statistic, + use_training_mode_neighborlist=True, + jit=False, + ) + # -------------------------------# + # test the forward pass through each of the models + output = potential(batch.nnp_input) + + # test that the output has the following keys and following dim + assert "per_system_energy" in output + assert "per_atom_energy" in output + + assert ( + output["per_system_energy"].shape == batch.metadata.per_system_energy.shape + ) # per system + assert ( + output["per_atom_energy"].shape[0] == batch.metadata.per_atom_force.shape[0] + ) # per atom + + pair_list = batch.nnp_input.pair_list + # pairlist is in ascending order in row 0 + assert torch.all(pair_list[0, 1:] >= pair_list[0, :-1]) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_jit(potential_name, single_batch_with_batchsize, prep_temp_dir): + # setup dataset + batch = single_batch_with_batchsize( + batch_size=1, dataset_name="qm9", local_cache_dir=str(prep_temp_dir) + ) + nnp_input = batch.nnp_input + + # -------------------------------# + # setup model + config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") + # test the forward pass through each of the models + potential = NeuralNetworkPotentialFactory.generate_potential( + potential_parameter=config["potential"], + ) + potential = torch.jit.script(potential) + # -------------------------------# + potential(nnp_input) + + +@pytest.mark.parametrize("dataset_name", ["QM9"]) +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +@pytest.mark.parametrize("mode", ["training", "inference"]) +def test_chemical_equivalency( + dataset_name, potential_name, mode, single_batch_with_batchsize, prep_temp_dir +): + nnp_input = single_batch_with_batchsize( + 32, dataset_name, str(prep_temp_dir) + ).nnp_input + + potential = setup_potential_for_test( + potential_name, + mode, + potential_seed=42, + local_cache_dir=str(prep_temp_dir), + ) + + output = potential(nnp_input) + validate_chemical_equivalence(output) + validate_per_atom_and_per_system_properties(output) + + +@pytest.mark.parametrize("dataset_name", ["QM9"]) +@pytest.mark.parametrize("potential_name", ["SchNet"]) +def test_different_neighborlists_for_inference( + dataset_name, potential_name, single_batch_with_batchsize, prep_temp_dir +): + + # NOTE: the training pairlist only works for a batchsize of 1 + nnp_input = single_batch_with_batchsize( + 1, dataset_name, str(prep_temp_dir) + ).nnp_input + + potential = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=True, + local_cache_dir=str(prep_temp_dir), + ) + + output_1 = potential(nnp_input) + + potential = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=False, + local_cache_dir=str(prep_temp_dir), + ) + + output_2 = potential(nnp_input) + + assert torch.allclose(output_1["per_system_energy"], output_2["per_system_energy"]) + + +@pytest.mark.parametrize("dataset_name", ["QM9"]) +@pytest.mark.parametrize( + "energy_expression", + [ + "short_range", + "short_range_and_long_range_electrostatic", + ], +) +@pytest.mark.parametrize("potential_name", ["SchNet"]) +@pytest.mark.parametrize("simulation_environment", ["PyTorch"]) +@pytest.mark.parametrize("jit", [False]) +def test_multiple_output_heads( + dataset_name, + energy_expression, + potential_name, + simulation_environment, + single_batch_with_batchsize, + jit, + prep_temp_dir, +): + """Test models with multiple output heads.""" + # Get input and set up model + nnp_input = single_batch_with_batchsize( + 32, dataset_name, str(prep_temp_dir) + ).nnp_input + config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") + config["runtime"].local_cache_dir = str(prep_temp_dir) + # Modify the config based on the energy expression + config = _add_per_atom_charge_to_predicted_properties(config) + if energy_expression == "short_range_and_long_range_electrostatic": + config = _add_per_atom_charge_to_properties_to_process(config) + config = _add_electrostatic_to_predicted_properties(config) + + nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] + model = initialize_model(simulation_environment, config, jit) + + # Perform the forward pass through the model + output = model(nnp_input) + + # Validate outputs + validate_output_shapes(output, nr_of_mols, energy_expression) + validate_chemical_equivalence(output) + validate_per_atom_and_per_system_properties(output) + + # Test charge correction + if energy_expression == "short_range_and_long_range_electrostatic": + per_system_total_charge, per_system_total_charge_uncorrected = ( + retrieve_molecular_charges(output, nnp_input.atomic_subsystem_indices) + ) + validate_charge_conservation( + per_system_total_charge, + per_system_total_charge_uncorrected, + output["per_system_total_charge"], + potential_name, + ) + + +@pytest.mark.parametrize("dataset_name", ["QM9"]) +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +@pytest.mark.parametrize("simulation_environment", ["JAX", "PyTorch"]) +def test_forward_pass( + dataset_name, + potential_name, + simulation_environment, + single_batch_with_batchsize, + prep_temp_dir, +): + # this test sends a single batch from different datasets through the model + + # get input and set up model + nnp_input = single_batch_with_batchsize( + 64, dataset_name, str(prep_temp_dir) + ).nnp_input + nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] + + potential = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=True, + simulation_environment=simulation_environment, + local_cache_dir=str(prep_temp_dir), + ) + nnp_input = prepare_input_for_model(nnp_input, potential) + + # perform the forward pass through each of the models + output = potential(nnp_input) + + # validate the output + validate_output_shapes(output, nr_of_mols, "short_range") + output, atomic_subsystem_indices = convert_to_pytorch_if_needed( + output, nnp_input, potential + ) + validate_chemical_equivalence(output) + + +import os + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="torchviz is not installed") +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_vis(potential_name, single_batch_with_batchsize, prep_temp_dir): + batch = single_batch_with_batchsize( + batch_size=32, dataset_name="SPICE2", local_cache_dir=str(prep_temp_dir) + ) + nnp_input = batch.nnp_input + from modelforge.utils.vis import visualize_model + + visualize_model(nnp_input, potential_name, str(prep_temp_dir)) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_calculate_energies_and_forces( + potential_name, single_batch_with_batchsize, prep_temp_dir +): + """ + Test the calculation of energies and forces for a molecule. + """ + import torch + + batch = single_batch_with_batchsize( + batch_size=32, dataset_name="SPICE2", local_cache_dir=str(prep_temp_dir) + ) + nnp_input = batch.nnp_input + + # read default parameters + trainer = setup_potential_for_test( + potential_name, + "training", + potential_seed=42, + local_cache_dir=str(prep_temp_dir), + ) + # get energy and force + E_training = trainer(nnp_input)["per_system_energy"] + F_training = -torch.autograd.grad( + E_training.sum(), nnp_input.positions, create_graph=True, retain_graph=True + )[0] + + # compare to inference model + potential = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=True, + jit=False, + local_cache_dir=str(prep_temp_dir), + ) + + # get energy and force + E_inference = potential(nnp_input)["per_system_energy"] + F_inference = -torch.autograd.grad( + E_inference.sum(), nnp_input.positions, create_graph=True, retain_graph=True + )[0] + + print(f"Energy training: {E_training}") + print(f"Energy inference: {E_inference}") + + # make sure that dimension are as expected + nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] + nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] + + assert E_inference.shape == (nr_of_mols, 1) # per system + assert F_inference.shape == (nr_of_atoms_per_batch, 3) # per atom + + # make sure that both agree on E and F + assert torch.allclose(E_inference, E_training, atol=1e-4) + assert torch.allclose(F_inference, F_training, atol=1e-4) + + # now compare agains the compiled inference model using the neighborlist + # optimized for MD. NOTE: this requires to reduce the batch size to 1 + # since the neighborlist is not batched + + # reduce batchsize + batch = single_batch_with_batchsize( + batch_size=1, dataset_name="SPICE2", local_cache_dir=str(prep_temp_dir) + ) + nnp_input = batch.nnp_input + + # get the inference model with inference neighborlist and compilre + # everything + potential = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=False, + jit=True, + local_cache_dir=str(prep_temp_dir), + ) + + # get energy and force + E_inference = potential(nnp_input)["per_system_energy"] + F_inference = -torch.autograd.grad( + E_inference.sum(), nnp_input.positions, create_graph=True, retain_graph=True + )[0] + # get energy and force + E_training = potential(nnp_input)["per_system_energy"] + F_training = -torch.autograd.grad( + E_training.sum(), nnp_input.positions, create_graph=True, retain_graph=True + )[0] + + nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] + nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] + + assert E_inference.shape == (nr_of_mols, 1) # per system + assert F_inference.shape == (nr_of_atoms_per_batch, 3) # per atom + + # make sure that both agree on E and F + assert torch.allclose(E_inference, E_training, atol=1e-4) + assert torch.allclose(F_inference, F_training, atol=1e-4) + + +def get_nr_of_mols(nnp_input): + import torch + import jax + import jax.numpy as jnp + + atomic_subsystem_indices = nnp_input.atomic_subsystem_indices + + if isinstance(atomic_subsystem_indices, torch.Tensor): + unique_indices = torch.unique(atomic_subsystem_indices) + nr_of_mols = unique_indices.shape[0] + + elif isinstance(atomic_subsystem_indices, jax.Array): + unique_indices = jnp.unique(atomic_subsystem_indices) + nr_of_mols = unique_indices.shape[0] + + else: + raise TypeError("Unsupported type. Expected a PyTorch tensor or a JAX array.") + + return nr_of_mols + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_calculate_energies_and_forces_with_jax( + potential_name, single_batch_with_batchsize, prep_temp_dir +): + """ + Test the calculation of energies and forces for a molecule. + """ + import torch + from modelforge.jax import convert_NNPInput_to_jax + + # get input and set up model + batch = single_batch_with_batchsize( + batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + + # conver tinput to jax + nnp_input = convert_NNPInput_to_jax(batch.nnp_input) + + potential = setup_potential_for_test( + potential_name, + "inference", + potential_seed=42, + use_training_mode_neighborlist=False, + jit=False, + simulation_environment="JAX", + local_cache_dir=str(prep_temp_dir), + ) + + # forward pass + result = potential(nnp_input)["per_system_energy"] + assert result.shape == batch.metadata.per_system_energy.shape + + from modelforge.utils.io import import_ + + jax = import_("jax") + + grad_fn = jax.grad(lambda pos: result.sum()) # Create a gradient function + forces = -grad_fn( + nnp_input.positions + ) # Evaluate gradient function and apply negative sign + + # test output shapes + nr_of_mols = get_nr_of_mols(nnp_input) + nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] + assert forces.shape == batch.metadata.per_atom_force.shape + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_casting(potential_name, single_batch_with_batchsize, prep_temp_dir): + # test dtype casting + import torch + + batch = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + batch_ = batch.to_dtype(dtype=torch.float64) + assert batch_.nnp_input.positions.dtype == torch.float64 + batch_ = batch_.to_dtype(dtype=torch.float32) + assert batch_.nnp_input.positions.dtype == torch.float32 + + nnp_input = batch.nnp_input.to_dtype(dtype=torch.float64) + assert nnp_input.positions.dtype == torch.float64 + nnp_input = batch.nnp_input.to_dtype(dtype=torch.float32) + assert nnp_input.positions.dtype == torch.float32 + nnp_input = batch.metadata.to_dtype(dtype=torch.float64) + + # cast input and model to torch.float64 + # read default parameters + config = load_configs_into_pydantic_models(f"{potential_name.lower()}", "qm9") + + potential = NeuralNetworkPotentialFactory.generate_potential( + simulation_environment="PyTorch", + potential_parameter=config["potential"], + use_training_mode_neighborlist=True, # can handel batched data + ) + model = potential.to(dtype=torch.float64) + nnp_input = batch.to_dtype(dtype=torch.float64).nnp_input + + potential(nnp_input) + + # cast input and model to torch.float64 + potential = NeuralNetworkPotentialFactory.generate_potential( + simulation_environment="PyTorch", + potential_parameter=config["potential"], + use_training_mode_neighborlist=True, # can handel batched data + ) + potential = potential.to(dtype=torch.float32) + nnp_input = batch.to_dtype(dtype=torch.float32).nnp_input + + potential(nnp_input) + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +@pytest.mark.parametrize("simulation_environment", ["PyTorch"]) +def test_equivariant_energies_and_forces( + potential_name, + simulation_environment, + single_batch_with_batchsize, + equivariance_utils, + prep_temp_dir, +): + """ + Test the calculation of energies and forces for a molecule. + NOTE: test will be adapted once we have a trained model. + """ + import torch + + precision = torch.float64 + simulation_environment: Literal["PyTorch", "JAX"] + + # initialize the models + potential = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name=potential_name, + simulation_environment=simulation_environment, + local_cache_dir=str(prep_temp_dir), + ).to(dtype=precision) + + # define the symmetry operations + translation, rotation, reflection = equivariance_utils + # define the tolerance + atol = 1e-1 + + # ------------------- # + # start the test + # reference values + nnp_input = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input.to_dtype(dtype=precision) + + reference_result = potential(nnp_input)["per_system_energy"] + reference_forces = -torch.autograd.grad( + reference_result.sum(), + nnp_input.positions, + )[0] + + # --------------------------------------- # + # translation test + # set up input + nnp_input = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input.to_dtype(dtype=precision) + translation_nnp_input = nnp_input.to_dtype(dtype=precision) + translation_nnp_input.positions = translation(translation_nnp_input.positions) + + translation_result = potential(translation_nnp_input)["per_system_energy"] + assert torch.allclose( + translation_result, + reference_result, + atol=atol, + ) + + translation_forces = -torch.autograd.grad( + translation_result.sum(), + translation_nnp_input.positions, + )[0] + + for t, r in zip(translation_forces, reference_forces): + if not torch.allclose(t, r, atol=atol): + print(t, r) + + assert torch.allclose( + translation_forces, + reference_forces, + atol=atol, + ) + + # --------------------------------------- # + # rotation test + # set up input + nnp_input = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input.to_dtype(dtype=precision) + rotation_input_data = nnp_input.to_dtype(dtype=precision) + rotation_input_data.positions = rotation(rotation_input_data.positions) + rotation_result = potential(rotation_input_data)["per_system_energy"] + + for t, r in zip(rotation_result, reference_result): + if not torch.allclose(t, r, atol=atol): + print(t, r) + + assert torch.allclose( + rotation_result, + reference_result, + atol=atol, + ) + + rotation_forces = -torch.autograd.grad( + rotation_result.sum(), + rotation_input_data.positions, + create_graph=True, + retain_graph=True, + )[0] + + rotate_reference = rotation(reference_forces) + print(rotation_forces, rotate_reference) + assert torch.allclose( + rotation_forces, + rotate_reference, + atol=atol, + ) + + # --------------------------------------- # + # reflection test + # set up input + nnp_input = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input.to_dtype(dtype=precision) + reflection_input_data = nnp_input.to_dtype(dtype=precision) + reflection_input_data.positions = reflection(reflection_input_data.positions) + reflection_result = potential(reflection_input_data)["per_system_energy"] + reflection_forces = -torch.autograd.grad( + reflection_result.sum(), + reflection_input_data.positions, + create_graph=True, + retain_graph=True, + )[0] + for t, r in zip(reflection_result, reference_result): + if not torch.allclose(t, r, atol=atol): + print(t, r) + + assert torch.allclose( + reflection_result, + reference_result, + atol=atol, + ) + + assert torch.allclose( + reflection_forces, + reflection(reference_forces), + atol=atol, + ) diff --git a/modelforge/tests/test_pt_lightning.py b/modelforge/tests/test_pt_lightning.py index 896d6056..fd941f7a 100644 --- a/modelforge/tests/test_pt_lightning.py +++ b/modelforge/tests/test_pt_lightning.py @@ -1,4 +1,13 @@ -def test_datamodule(): +import pytest + + +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_pt_temp") + return fn + + +def test_datamodule(prep_temp_dir): # This is an example script that trains an implemented model on the QM9 dataset. from modelforge.dataset.dataset import DataModule @@ -7,5 +16,5 @@ def test_datamodule(): dm = DataModule( name="QM9", batch_size=512, + local_cache_dir=str(prep_temp_dir), ) - diff --git a/modelforge/tests/test_remote.py b/modelforge/tests/test_remote.py index e6aaf4bd..545b3ede 100644 --- a/modelforge/tests/test_remote.py +++ b/modelforge/tests/test_remote.py @@ -8,7 +8,7 @@ @pytest.fixture(scope="session") def prep_temp_dir(tmp_path_factory): - fn = tmp_path_factory.mktemp("remote_test") + fn = tmp_path_factory.mktemp("test_remote_temp") return fn @@ -27,18 +27,20 @@ def test_is_url(): def test_download_from_url(prep_temp_dir): - url = "https://raw.githubusercontent.com/choderalab/modelforge/e3e65e15e23ccc55d03dd7abb4b9add7a7dd15c3/modelforge/modelforge.py" - checksum = "66ec18ca5db3df5791ff1ffc584363a8" + url = "https://zenodo.org/records/3401581/files/PTC-CMC/atools_ml-v0.1.zip" + checksum = "194cde222565dca8657d8521e5df1fd8" + + name = "atools_ml-v0.1.zip" # Download the file download_from_url( url, md5_checksum=checksum, output_path=str(prep_temp_dir), - output_filename="modelforge.py", + output_filename=name, force_download=True, ) - file_name_path = str(prep_temp_dir) + "/modelforge.py" + file_name_path = str(prep_temp_dir) + f"/{name}" assert os.path.isfile(file_name_path) # create a dummy document to test the case where @@ -51,11 +53,11 @@ def test_download_from_url(prep_temp_dir): url, md5_checksum=checksum, output_path=str(prep_temp_dir), - output_filename="modelforge.py", + output_filename=name, force_download=False, ) - file_name_path = str(prep_temp_dir) + "/modelforge.py" + file_name_path = str(prep_temp_dir) + f"/{name}" assert os.path.isfile(file_name_path) # let us change the expected checksum to cause a failure @@ -66,59 +68,7 @@ def test_download_from_url(prep_temp_dir): url, md5_checksum="checksum_garbage", output_path=str(prep_temp_dir), - output_filename="modelforge.py", - force_download=True, - ) - - -def test_download_from_figshare(prep_temp_dir): - url = "https://figshare.com/ndownloader/files/22247589" - name = download_from_figshare( - url=url, - md5_checksum="c1459c5ddce7bb94800032aa3d04788e", - output_path=str(prep_temp_dir), - force_download=True, - ) - - file_name_path = str(prep_temp_dir) + f"/{name}" - assert os.path.isfile(file_name_path) - - # create a dummy document to test the case where - # the checksum doesn't match so it will redownload - with open(file_name_path, "w") as f: - f.write("dummy document") - - # This will force a download because the checksum doesn't match - url = "https://figshare.com/ndownloader/files/22247589" - name = download_from_figshare( - url=url, - md5_checksum="c1459c5ddce7bb94800032aa3d04788e", - output_path=str(prep_temp_dir), - force_download=False, - ) - - file_name_path = str(prep_temp_dir) + f"/{name}" - assert os.path.isfile(file_name_path) - - # the length of this file isn't listed in the headers - # this will check to make sure we can handle this case - url = "https://figshare.com/ndownloader/files/30975751" - name = download_from_figshare( - url=url, - md5_checksum="efa40abff1f71c121f6f0d444c18d5b3", - output_path=str(prep_temp_dir), - force_download=True, - ) - - file_name_path = str(prep_temp_dir) + f"/{name}" - assert os.path.isfile(file_name_path) - - with pytest.raises(Exception): - url = "https://choderalab.com/ndownloader/files/22247589" - name = download_from_figshare( - url=url, - md5_checksum="c1459c5ddce7bb94800032aa3d04788e", - output_path=str(prep_temp_dir), + output_filename=name, force_download=True, ) @@ -134,56 +84,16 @@ def test_fetch_record_id(): fetch_url_from_doi(doi="10.5281/zenodo.3588339", timeout=0.0000000000001) -def test_download_from_zenodo(prep_temp_dir): - url = "https://zenodo.org/records/3401581/files/PTC-CMC/atools_ml-v0.1.zip" - zenodo_checksum = "194cde222565dca8657d8521e5df1fd8" - name = download_from_zenodo( - url=url, - md5_checksum=zenodo_checksum, - output_path=str(prep_temp_dir), - force_download=True, - ) - - file_name_path = str(prep_temp_dir) + f"/{name}" - assert os.path.isfile(file_name_path) - - # create a dummy document to test the case where - # the checksum doesn't match so it will redownload - with open(file_name_path, "w") as f: - f.write("dummy document") - - # make sure that we redownload the file because the checksum of the - # existing file doesn't match - url = "https://zenodo.org/records/3401581/files/PTC-CMC/atools_ml-v0.1.zip" - zenodo_checksum = "194cde222565dca8657d8521e5df1fd8" - name = download_from_zenodo( - url=url, - md5_checksum=zenodo_checksum, - output_path=str(prep_temp_dir), - force_download=False, - ) - - file_name_path = str(prep_temp_dir) + f"/{name}" - assert os.path.isfile(file_name_path) - - with pytest.raises(Exception): - url = "https://choderalab.com/22247589" - name = download_from_zenodo( - url=url, - md5_checksum=zenodo_checksum, - output_path=str(prep_temp_dir), - force_download=True, - ) - - def test_md5_calculation(prep_temp_dir): url = "https://zenodo.org/records/3401581/files/PTC-CMC/atools_ml-v0.1.zip" zenodo_checksum = "194cde222565dca8657d8521e5df1fd8" - name = download_from_zenodo( + name = "atools_ml-v0.1.zip" + download_from_url( url=url, md5_checksum=zenodo_checksum, output_path=str(prep_temp_dir), + output_filename=name, force_download=True, ) @@ -196,9 +106,10 @@ def test_md5_calculation(prep_temp_dir): with pytest.raises(Exception): bad_checksum = "294badmd5checksumthatwontwork9de" - name = download_from_zenodo( + download_from_url( url=url, md5_checksum=bad_checksum, output_path=str(prep_temp_dir), + output_filename=name, force_download=True, ) diff --git a/modelforge/tests/test_representation.py b/modelforge/tests/test_representation.py new file mode 100644 index 00000000..4ae23cd4 --- /dev/null +++ b/modelforge/tests/test_representation.py @@ -0,0 +1,500 @@ +import torch +import pytest + + +def test_radial_symmetry_function_implementation(): + """ + Test the Radial Symmetry function implementation. + """ + import torch + from openff.units import unit + import numpy as np + from modelforge.potential.representation import ( + CosineAttenuationFunction, + GaussianRadialBasisFunctionWithScaling, + ) + + cutoff_module = CosineAttenuationFunction( + cutoff=unit.Quantity(5.0, unit.angstrom).to(unit.nanometer).m + ) + + class RadialSymmetryFunctionTest(GaussianRadialBasisFunctionWithScaling): + @staticmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + centers = torch.linspace( + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype=dtype, + ) + return centers + + @staticmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + scale_factors = torch.full( + (number_of_radial_basis_functions,), + (_min_distance_in_nanometer - _max_distance_in_nanometer) + / number_of_radial_basis_functions, + ) + scale_factors = (scale_factors * -15_000) ** -0.5 + return scale_factors + + RSF = RadialSymmetryFunctionTest( + number_of_radial_basis_functions=18, + max_distance=unit.Quantity(5.0, unit.angstrom).to(unit.nanometer).m, + ) + # test a single distance + d_ij = torch.tensor([[0.2]]) + radial_expension = RSF(d_ij) + + expected_output = np.array( + [ + 5.7777413e-08, + 5.4214674e-06, + 2.4740110e-04, + 5.4905377e-03, + 5.9259072e-02, + 3.1104434e-01, + 7.9399312e-01, + 9.8568588e-01, + 5.9509689e-01, + 1.7472850e-01, + 2.4949821e-02, + 1.7326004e-03, + 5.8513560e-05, + 9.6104134e-07, + 7.6763511e-09, + 2.9819147e-11, + 5.6333109e-14, + 5.1755549e-17, + ], + dtype=np.float32, + ) + + assert np.allclose(radial_expension.numpy().flatten(), expected_output, rtol=1e-3) + + # test multiple distances with cutoff + d_ij = torch.tensor([[r] for r in np.linspace(0, 0.5, 10)]) + radial_expension = RSF(d_ij) * cutoff_module(d_ij) + + expected_output = np.array( + [ + [ + 1.00000000e00, + 6.97370611e-01, + 2.36512753e-01, + 3.90097089e-02, + 3.12909145e-03, + 1.22064879e-04, + 2.31574554e-06, + 2.13657562e-08, + 9.58678574e-11, + 2.09196141e-13, + 2.22005077e-16, + 1.14577532e-19, + 2.87583090e-23, + 3.51038337e-27, + 2.08388175e-31, + 6.01615362e-36, + 8.44679753e-41, + 5.76756600e-46, + ], + [ + 2.68038176e-01, + 7.29490887e-01, + 9.65540222e-01, + 6.21510012e-01, + 1.94559846e-01, + 2.96200218e-02, + 2.19303227e-03, + 7.89645189e-05, + 1.38275834e-06, + 1.17757010e-08, + 4.87703136e-11, + 9.82316969e-14, + 9.62221521e-17, + 4.58380155e-20, + 1.06194951e-23, + 1.19649050e-27, + 6.55604552e-32, + 1.74703654e-36, + ], + [ + 5.15165267e-03, + 5.47178933e-02, + 2.82643788e-01, + 7.10030194e-01, + 8.67443988e-01, + 5.15386799e-01, + 1.48919812e-01, + 2.09266151e-02, + 1.43012111e-03, + 4.75305832e-05, + 7.68248035e-07, + 6.03888967e-09, + 2.30855409e-11, + 4.29190731e-14, + 3.88050222e-17, + 1.70629005e-20, + 3.64875837e-24, + 3.79458837e-28, + ], + [ + 7.05512776e-06, + 2.92447055e-04, + 5.89544925e-03, + 5.77981439e-02, + 2.75573882e-01, + 6.38983424e-01, + 7.20556963e-01, + 3.95161266e-01, + 1.05392022e-01, + 1.36699807e-02, + 8.62294776e-04, + 2.64527563e-05, + 3.94651201e-07, + 2.86340809e-09, + 1.01036987e-11, + 1.73382336e-14, + 1.44696036e-17, + 5.87267193e-21, + ], + [ + 6.79841545e-10, + 1.09978970e-07, + 8.65244557e-06, + 3.31051436e-04, + 6.15997825e-03, + 5.57430086e-02, + 2.45317579e-01, + 5.25042257e-01, + 5.46496226e-01, + 2.76635027e-01, + 6.81011682e-02, + 8.15322217e-03, + 4.74713206e-04, + 1.34419004e-05, + 1.85104660e-07, + 1.23965647e-09, + 4.03750130e-12, + 6.39515861e-15, + ], + [ + 4.50275565e-15, + 2.84275808e-12, + 8.72828077e-10, + 1.30330158e-07, + 9.46429271e-06, + 3.34240505e-04, + 5.74059467e-03, + 4.79492711e-02, + 1.94775558e-01, + 3.84781601e-01, + 3.69675978e-01, + 1.72725113e-01, + 3.92479574e-02, + 4.33716512e-03, + 2.33089213e-04, + 6.09208166e-06, + 7.74348707e-08, + 4.78668403e-10, + ], + [ + 1.95755731e-21, + 4.82320349e-18, + 5.77941614e-15, + 3.36790471e-12, + 9.54470642e-10, + 1.31550705e-07, + 8.81760261e-06, + 2.87432047e-04, + 4.55666577e-03, + 3.51307040e-02, + 1.31720486e-01, + 2.40185684e-01, + 2.12994423e-01, + 9.18579329e-02, + 1.92660386e-02, + 1.96514909e-03, + 9.74823310e-05, + 2.35170925e-06, + ], + [ + 5.02685557e-29, + 4.83367095e-25, + 2.26039866e-21, + 5.14067897e-18, + 5.68568509e-15, + 3.05825078e-12, + 7.99999982e-10, + 1.01773376e-07, + 6.29659424e-06, + 1.89454631e-04, + 2.77224261e-03, + 1.97280685e-02, + 6.82755520e-02, + 1.14914067e-01, + 9.40607618e-02, + 3.74430407e-02, + 7.24871542e-03, + 6.82461743e-04, + ], + [ + 5.43174696e-38, + 2.03835481e-33, + 3.72003749e-29, + 3.30173621e-25, + 1.42516199e-21, + 2.99167362e-18, + 3.05415190e-15, + 1.51633227e-12, + 3.66121676e-10, + 4.29917177e-08, + 2.45510656e-06, + 6.81841165e-05, + 9.20923191e-04, + 6.04910223e-03, + 1.93234986e-02, + 3.00198084e-02, + 2.26807491e-02, + 8.33362963e-03, + ], + [ + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + 0.00000000e00, + ], + ] + ) + + assert np.allclose(radial_expension.numpy(), expected_output, rtol=1e-3) + + +def test_schnet_rbf(): + """ + Test the SchnetRadialBasisFunction class. + """ + from modelforge.potential.representation import SchnetRadialBasisFunction + + # Test parameters + distances = torch.tensor([[0.5], [1.0], [1.5]], dtype=torch.float32) / 10 + number_of_radial_basis_functions = 3 + max_distance = 2.0 / 10 + min_distance = 0.0 + dtype = torch.float32 + + # Instantiate the RBF + rbf = SchnetRadialBasisFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=max_distance, + min_distance=min_distance, + dtype=dtype, + trainable_centers_and_scale_factors=False, + ) + + # Compute expected outputs + centers = rbf.radial_basis_centers # Shape: [number_of_radial_basis_functions] + scale_factors = rbf.radial_scale_factor # Shape: [number_of_radial_basis_functions] + + # Expand dimensions for broadcasting + distances_expanded = distances # Shape: [number_of_pairs, 1] + centers_expanded = centers.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + scale_factors_expanded = scale_factors.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + + # Calculate nondimensionalized distances and expected outputs + diff = distances_expanded - centers_expanded + nondim_distances = diff / scale_factors_expanded + expected_output = torch.exp(-(nondim_distances**2)) + + # Get actual outputs + actual_output = rbf(distances) + + # Assertions + assert actual_output.shape == expected_output.shape, "Output shape mismatch" + assert torch.allclose( + actual_output, expected_output, atol=1e-6 + ), "Outputs do not match expected values for SchnetRadialBasisFunction" + + +def test_ani_rbf(): + """ + Test the AniRadialBasisFunction class. + """ + from modelforge.potential.representation import AniRadialBasisFunction + + # Test parameters + distances = torch.tensor([[0.5], [1.0], [1.5]], dtype=torch.float32) + number_of_radial_basis_functions = 3 + max_distance = 2.0 + min_distance = 0.0 + dtype = torch.float32 + + # Instantiate the RBF + rbf = AniRadialBasisFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=max_distance, + min_distance=min_distance, + dtype=dtype, + trainable_centers_and_scale_factors=False, + ) + + # Compute expected outputs + centers = rbf.radial_basis_centers # Shape: [number_of_radial_basis_functions] + scale_factors = rbf.radial_scale_factor # Shape: [number_of_radial_basis_functions] + + # Expand dimensions for broadcasting + distances_expanded = distances # Shape: [number_of_pairs, 1] + centers_expanded = centers.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + scale_factors_expanded = scale_factors.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + + # Calculate nondimensionalized distances and expected outputs + diff = distances_expanded - centers_expanded + nondim_distances = diff / scale_factors_expanded + expected_output = 0.25 * torch.exp(-(nondim_distances**2)) # Include prefactor + + # Get actual outputs + actual_output = rbf(distances) + + # Assertions + assert actual_output.shape == expected_output.shape, "Output shape mismatch" + assert torch.allclose( + actual_output, expected_output, atol=1e-6 + ), "Outputs do not match expected values for AniRadialBasisFunction" + + +def test_physnet_rbf(): + """ + Test the PhysNetRadialBasisFunction class. + """ + from modelforge.potential.representation import PhysNetRadialBasisFunction + + # Test parameters + distances = torch.tensor([[0.5], [1.0], [1.5]], dtype=torch.float32) / 10 + number_of_radial_basis_functions = 3 + max_distance = 2.0 / 10 + min_distance = 0.0 + alpha = 0.1 + dtype = torch.float32 + + # Instantiate the RBF + rbf = PhysNetRadialBasisFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=max_distance, + min_distance=min_distance, + alpha=alpha, + dtype=dtype, + trainable_centers_and_scale_factors=False, + ) + + # Compute expected outputs + centers = rbf.radial_basis_centers # Unitless centers + scale_factors = rbf.radial_scale_factor # Unitless scale factors + + # Expand dimensions for broadcasting + distances_expanded = distances # Shape: [number_of_pairs, 1] + centers_expanded = centers.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + scale_factors_expanded = scale_factors.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + + # Nondimensionalization as per PhysNet + nondim_distances = ( + torch.exp((-distances + min_distance) / alpha) - centers_expanded + ) / scale_factors_expanded + expected_output = torch.exp(-(nondim_distances**2)) + + # Get actual outputs + actual_output = rbf(distances) + + # Assertions + assert actual_output.shape == expected_output.shape, "Output shape mismatch" + assert torch.allclose( + actual_output, expected_output, atol=1e-6 + ), "Outputs do not match expected values for PhysNetRadialBasisFunction" + + +def test_tensornet_rbf(): + """ + Test the TensorNetRadialBasisFunction class. + """ + from modelforge.potential.representation import TensorNetRadialBasisFunction + + # Test parameters + distances = torch.tensor([[0.5], [1.0], [1.5]], dtype=torch.float32) + number_of_radial_basis_functions = 3 + max_distance = 2.0 + min_distance = 0.0 + dtype = torch.float32 + + # Instantiate the RBF + rbf = TensorNetRadialBasisFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=max_distance, + min_distance=min_distance, + dtype=dtype, + trainable_centers_and_scale_factors=False, + ) + + # Compute expected outputs + centers = rbf.radial_basis_centers # Unitless centers + scale_factors = rbf.radial_scale_factor # Unitless scale factors + alpha = 0.1 # As per TensorNet implementation + + # Expand dimensions for broadcasting + distances_expanded = distances # Shape: [number_of_pairs, 1] + centers_expanded = centers.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + scale_factors_expanded = scale_factors.unsqueeze( + 0 + ) # Shape: [1, number_of_radial_basis_functions] + + # Nondimensionalization as per TensorNet + nondim_distances = ( + torch.exp((-distances + min_distance) / alpha) - centers_expanded + ) / scale_factors_expanded + expected_output = torch.exp(-(nondim_distances**2)) + + # Get actual outputs + actual_output = rbf(distances) + + # Assertions + assert actual_output.shape == expected_output.shape, "Output shape mismatch" + assert torch.allclose( + actual_output, expected_output, atol=1e-6 + ), "Outputs do not match expected values for TensorNetRadialBasisFunction" diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index b69b1e3d..907b52bf 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -1,53 +1,56 @@ import os +from sys import platform -import jax.random import jax.numpy as jnp +import jax.random +import numpy as onp import pytest +import sake as reference_sake import torch -import numpy as onp -from modelforge.potential.sake import SAKE, SAKEInteraction -import sake as reference_sake -from sys import platform +from modelforge.potential.sake import SAKEInteraction +from modelforge.tests.helper_functions import setup_potential_for_test ON_MAC = platform == "darwin" +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_sake_temp") + return fn + + def test_init(): """Test initialization of the SAKE neural network potential.""" - from modelforge.tests.test_models import load_configs - # read default parameters - config = load_configs(f"sake", "qm9") - - # initialize model - sake = SAKE( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], + sake = setup_potential_for_test( + "sake", + "training", + local_cache_dir=str(prep_temp_dir), ) + assert sake is not None, "SAKE model should be initialized." from openff.units import unit -def test_forward(single_batch_with_batchsize_64): +def test_forward(single_batch_with_batchsize, prep_temp_dir): """ Test the forward pass of the SAKE model. """ # get methane input - methane = single_batch_with_batchsize_64.nnp_input - - from modelforge.tests.test_models import load_configs - - # read default parameters - config = load_configs(f"sake", "qm9") + batch = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + methane = batch.nnp_input - sake = SAKE( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], + sake = setup_potential_for_test( + "sake", + "training", + local_cache_dir=str(prep_temp_dir), ) - energy = sake(methane)["per_molecule_energy"] + energy = sake(methane)["per_system_energy"] nr_of_mols = methane.atomic_subsystem_indices.unique().shape[0] assert ( @@ -70,10 +73,10 @@ def test_interaction_forward(): nr_coefficients=23, nr_heads=29, activation=torch.nn.ReLU(), - cutoff=(5.0 * unit.angstrom), + maximum_interaction_radius=(5.0 * unit.angstrom).to(unit.nanometer).m, number_of_radial_basis_functions=53, epsilon=1e-5, - scale_factor=(1.0 * unit.nanometer), + scale_factor=(1.0 * unit.nanometer).m, ) h = torch.randn(nr_atoms, nr_atom_basis) x = torch.randn(nr_atoms, geometry_basis) @@ -87,67 +90,64 @@ def test_interaction_forward(): @pytest.mark.parametrize("eq_atol", [3e-1]) @pytest.mark.parametrize("h_atol", [8e-2]) -def test_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize_64): - import torch - from modelforge.potential.sake import SAKE +def test_layer_equivariance( + h_atol, eq_atol, single_batch_with_batchsize, prep_temp_dir +): from dataclasses import replace + import torch + # Model parameters - nr_atom_basis = 11 torch.manual_seed(1884) - # define a rotation matrix in 3D that rotates by 90 degrees around the z-axis - # (clockwise when looking along the z-axis towards the origin) + # define a rotation matrix in 3D that rotates by 90 degrees around the + # z-axis (clockwise when looking along the z-axis towards the origin) rotation_matrix = torch.tensor([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - from modelforge.tests.test_models import load_configs - - config = load_configs(f"sake", "qm9") - # Extract parameters - core_parameter = config["potential"]["core_parameter"] - core_parameter["number_of_atom_features"] = nr_atom_basis - sake = SAKE( - **core_parameter, - postprocessing_parameter=config["potential"]["postprocessing_parameter"], + sake = setup_potential_for_test( + "sake", + "training", + local_cache_dir=str(prep_temp_dir), ) # get methane input - methane = single_batch_with_batchsize_64.nnp_input - perturbed_methane_input = replace(methane) - perturbed_methane_input.positions = torch.matmul(methane.positions, rotation_matrix) + nnp_input = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input + ref_nnp_input = single_batch_with_batchsize( + batch_size=64, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input + + nnp_input.positions = torch.matmul(nnp_input.positions, rotation_matrix) # prepare reference and perturbed inputs - pairlist_output = sake.input_preparation.prepare_inputs(methane) - reference_prepared_input = sake.core_module._model_specific_input_preparation( - methane, pairlist_output - ) - reference_v_torch = torch.randn_like(reference_prepared_input.positions) + neighborlist = sake.neighborlist(nnp_input) + reference_v_torch = torch.randn_like(nnp_input.positions) - pairlist_output = sake.input_preparation.prepare_inputs(perturbed_methane_input) - perturbed_prepared_input = sake.core_module._model_specific_input_preparation( - perturbed_methane_input, pairlist_output - ) perturbed_v_torch = torch.matmul(reference_v_torch, rotation_matrix) + emedding = torch.nn.Embedding(101, 11) + atomic_embedding = emedding(nnp_input.atomic_numbers) + ( reference_h_out_torch, reference_x_out_torch, reference_v_out_torch, - ) = sake.core_module.interaction_modules[0]( - reference_prepared_input.atomic_embedding, - reference_prepared_input.positions, + ) = sake.core_network.interaction_modules[0]( + atomic_embedding, + ref_nnp_input.positions, reference_v_torch, - reference_prepared_input.pair_indices, + neighborlist.pair_indices, ) ( perturbed_h_out_torch, perturbed_x_out_torch, perturbed_v_out_torch, - ) = sake.core_module.interaction_modules[0]( - perturbed_prepared_input.atomic_embedding, - perturbed_prepared_input.positions, + ) = sake.core_network.interaction_modules[0]( + atomic_embedding, + nnp_input.positions, perturbed_v_torch, - perturbed_prepared_input.pair_indices, + neighborlist.pair_indices, ) # x and v are equivariant, h is invariant @@ -165,7 +165,7 @@ def test_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize_64): def make_reference_equivalent_sake_interaction(out_features, hidden_features, nr_heads): - cutoff = 5.0 * unit.angstrom + radial_max_distance = unit.Quantity(5.0, unit.angstrom) # Define the modelforge layer mf_sake_block = SAKEInteraction( nr_atom_basis=out_features, @@ -178,10 +178,10 @@ def make_reference_equivalent_sake_interaction(out_features, hidden_features, nr nr_coefficients=(nr_heads * hidden_features), nr_heads=nr_heads, activation=torch.nn.SiLU(), - cutoff=cutoff, + maximum_interaction_radius=radial_max_distance.to(unit.nanometer).m, number_of_radial_basis_functions=50, epsilon=1e-5, - scale_factor=(1.0 * unit.nanometer), + scale_factor=unit.Quantity(1.0, unit.nanometer).to(unit.nanometer).m, ) # Define the reference layer @@ -223,32 +223,31 @@ def make_equivalent_pairlist_mask(key, nr_atoms, nr_pairs, include_self_pairs): def test_radial_symmetry_function_against_reference(): - from modelforge.potential.utils import ( - PhysNetRadialBasisFunction, - ) from sake.utils import ExpNormalSmearing as RefExpNormalSmearing + from modelforge.potential import PhysNetRadialBasisFunction + nr_atoms = 1 number_of_radial_basis_functions = 10 - cutoff_upper = 6.0 * unit.nanometer - cutoff_lower = 2.0 * unit.nanometer + cutoff_upper = unit.Quantity(6.0, unit.nanometer) + cutoff_lower = unit.Quantity(2.0, unit.nanometer) radial_symmetry_function_module = PhysNetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, - max_distance=cutoff_upper, - min_distance=cutoff_lower, + max_distance=cutoff_upper.to(unit.nanometer).m, + min_distance=cutoff_lower.to(unit.nanometer).m, dtype=torch.float32, ) ref_radial_basis_module = RefExpNormalSmearing( num_rbf=number_of_radial_basis_functions, - cutoff_upper=cutoff_upper.m, - cutoff_lower=cutoff_lower.m, + cutoff_upper=cutoff_upper.to(unit.nanometer).m, + cutoff_lower=cutoff_lower.to(unit.nanometer).m, ) key = jax.random.PRNGKey(1884) # Generate random input data in JAX d_ij_jax = jax.random.uniform(key, (nr_atoms, nr_atoms, 1)) - d_ij = torch.from_numpy(onp.array(d_ij_jax)).reshape((nr_atoms ** 2, 1)) + d_ij = torch.from_numpy(onp.array(d_ij_jax)).reshape((nr_atoms**2, 1)) mf_rbf = radial_symmetry_function_module(d_ij) variables = ref_radial_basis_module.init(key, d_ij_jax) @@ -256,10 +255,14 @@ def test_radial_symmetry_function_against_reference(): assert torch.allclose( torch.from_numpy(onp.array(variables["params"]["means"])), radial_symmetry_function_module.radial_basis_centers.detach().T, + atol=1e-1, + rtol=1e-1, ) assert torch.allclose( torch.from_numpy(onp.array(variables["params"]["betas"])) ** -0.5, radial_symmetry_function_module.radial_scale_factor.detach().T, + atol=1e-2, + rtol=1e-2, ) ref_rbf = ref_radial_basis_module.apply(variables, d_ij_jax) @@ -267,7 +270,7 @@ def test_radial_symmetry_function_against_reference(): assert torch.allclose( mf_rbf, torch.from_numpy(onp.array(ref_rbf)).reshape( - nr_atoms ** 2, number_of_radial_basis_functions + nr_atoms**2, number_of_radial_basis_functions ), ) @@ -407,195 +410,31 @@ def test_sake_layer_against_reference(include_self_pairs, v_is_none): ) -def test_model_against_reference(single_batch_with_batchsize_1): - nr_heads = 5 - nr_atom_basis = 11 - max_Z = 13 - key = jax.random.PRNGKey(1884) - torch.manual_seed(1884) - nr_interaction_blocks = 3 - cutoff = 5.0 * unit.angstrom - - mf_sake = SAKE( - max_Z=max_Z, - number_of_atom_features=nr_atom_basis, - number_of_interaction_modules=nr_interaction_blocks, - number_of_spatial_attention_heads=nr_heads, - cutoff=cutoff, - number_of_radial_basis_functions=50, - epsilon=1e-8, - postprocessing_parameter={ - "per_atom_energy": { - "normalize": True, - "from_atom_to_molecule_reduction": True, - "keep_per_atom_property": True, - } - }, - ) - - ref_sake = reference_sake.models.DenseSAKEModel( - hidden_features=nr_atom_basis, - out_features=1, - depth=nr_interaction_blocks, - n_heads=nr_heads, - cutoff=None, - ) - - # get methane input - methane = single_batch_with_batchsize_1.nnp_input - pairlist_output = mf_sake.input_preparation.prepare_inputs(methane) - prepared_methane = mf_sake.core_module._model_specific_input_preparation( - methane, pairlist_output - ) - - mask = jnp.zeros( - (prepared_methane.number_of_atoms, prepared_methane.number_of_atoms) - ) - for i in range(prepared_methane.pair_indices.shape[1]): - mask = mask.at[ - prepared_methane.pair_indices[0, i].item(), - prepared_methane.pair_indices[1, i].item(), - ].set(1) - - h = jax.nn.one_hot(prepared_methane.atomic_numbers.detach().numpy(), max_Z) - x = prepared_methane.positions.detach().numpy() - variables = ref_sake.init(key, h, x, mask=mask) - - variables["params"]["embedding_in"]["kernel"] = ( - mf_sake.core_module.embedding.weight.detach().numpy().T - ) - variables["params"]["embedding_in"]["bias"] = ( - mf_sake.core_module.embedding.bias.detach().numpy().T - ) - variables["params"]["embedding_out"]["layers_0"]["kernel"] = ( - mf_sake.core_module.energy_layer[0].weight.detach().numpy().T - ) - variables["params"]["embedding_out"]["layers_0"]["bias"] = ( - mf_sake.core_module.energy_layer[0].bias.detach().numpy().T - ) - variables["params"]["embedding_out"]["layers_2"]["kernel"] = ( - mf_sake.core_module.energy_layer[2].weight.detach().numpy().T - ) - variables["params"]["embedding_out"]["layers_2"]["bias"] = ( - mf_sake.core_module.energy_layer[2].bias.detach().numpy().T - ) - layers = ( - (layer_name, variables["params"][layer_name]) - for layer_name in variables["params"].keys() - if layer_name.startswith("d") - ) - for (layer_name, layer), mf_sake_block in zip( - layers, mf_sake.core_module.interaction_modules.children() - ): - layer["edge_model"]["kernel"]["betas"] = ( - mf_sake_block.radial_symmetry_function_module.radial_scale_factor.detach() - .numpy() - .T - ) - layer["edge_model"]["kernel"]["means"] = ( - mf_sake_block.radial_symmetry_function_module.radial_basis_centers.detach() - .numpy() - .T - ) - layer["edge_model"]["mlp_in"]["bias"] = ( - mf_sake_block.edge_mlp_in.bias.detach().numpy().T - ) - layer["edge_model"]["mlp_in"]["kernel"] = ( - mf_sake_block.edge_mlp_in.weight.detach().numpy().T - ) - layer["edge_model"]["mlp_out"]["layers_0"]["bias"] = ( - mf_sake_block.edge_mlp_out[0].bias.detach().numpy().T - ) - layer["edge_model"]["mlp_out"]["layers_0"]["kernel"] = ( - mf_sake_block.edge_mlp_out[0].weight.detach().numpy().T - ) - layer["edge_model"]["mlp_out"]["layers_2"]["bias"] = ( - mf_sake_block.edge_mlp_out[1].bias.detach().numpy().T - ) - layer["edge_model"]["mlp_out"]["layers_2"]["kernel"] = ( - mf_sake_block.edge_mlp_out[1].weight.detach().numpy().T - ) - layer["node_mlp"]["layers_0"]["bias"] = ( - mf_sake_block.node_mlp[0].bias.detach().numpy().T - ) - layer["node_mlp"]["layers_0"]["kernel"] = ( - mf_sake_block.node_mlp[0].weight.detach().numpy().T - ) - layer["node_mlp"]["layers_2"]["bias"] = ( - mf_sake_block.node_mlp[1].bias.detach().numpy().T - ) - layer["node_mlp"]["layers_2"]["kernel"] = ( - mf_sake_block.node_mlp[1].weight.detach().numpy().T - ) - layer["post_norm_mlp"]["layers_0"]["bias"] = ( - mf_sake_block.post_norm_mlp[0].bias.detach().numpy().T - ) - layer["post_norm_mlp"]["layers_0"]["kernel"] = ( - mf_sake_block.post_norm_mlp[0].weight.detach().numpy().T - ) - layer["post_norm_mlp"]["layers_2"]["bias"] = ( - mf_sake_block.post_norm_mlp[1].bias.detach().numpy().T - ) - layer["post_norm_mlp"]["layers_2"]["kernel"] = ( - mf_sake_block.post_norm_mlp[1].weight.detach().numpy().T - ) - layer["semantic_attention_mlp"]["layers_0"]["bias"] = ( - mf_sake_block.semantic_attention_mlp.bias.detach().numpy().T - ) - layer["semantic_attention_mlp"]["layers_0"]["kernel"] = ( - mf_sake_block.semantic_attention_mlp.weight.detach().numpy().T - ) - - if layer_name != "d0": - layer["velocity_mlp"]["layers_0"]["kernel"] = ( - mf_sake_block.velocity_mlp[0].weight.detach().numpy().T - ) - layer["velocity_mlp"]["layers_0"]["bias"] = ( - mf_sake_block.velocity_mlp[0].bias.detach().numpy().T - ) - layer["velocity_mlp"]["layers_2"]["kernel"] = ( - mf_sake_block.velocity_mlp[1].weight.detach().numpy().T - ) - layer["v_mixing"]["kernel"] = ( - mf_sake_block.v_mixing_mlp.weight.detach().numpy().T - ) - layer["x_mixing"]["layers_0"]["kernel"] = ( - mf_sake_block.x_mixing_mlp.weight.detach().numpy().T - ) - - # jax.tree_util.tree_map_with_path(lambda path, leaf: print(path, leaf.shape), variables) - - mf_out = mf_sake(methane) - ref_out = ref_sake.apply(variables, h, x, mask=mask)[0].sum(-2) - # ref_out is nan, so we can't compare it to the modelforge output - - print(f"{mf_out['per_molecule_energy']=}") - print(f"{ref_out=}") - # assert torch.allclose(mf_out.E, torch.from_numpy(onp.array(ref_out[0]))) +import pytest -def test_model_invariance(single_batch_with_batchsize_1): +def test_model_invariance(single_batch_with_batchsize, prep_temp_dir): from dataclasses import replace - from modelforge.tests.test_models import load_configs - - config = load_configs(f"sake", "qm9") - - # initialize model - model = SAKE( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], + sake = setup_potential_for_test( + "sake", + "training", + local_cache_dir=str(prep_temp_dir), ) # get methane input - methane = single_batch_with_batchsize_1.nnp_input + methane = single_batch_with_batchsize( + batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input + reference_methane = single_batch_with_batchsize( + batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input rotation_matrix = torch.tensor([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) - perturbed_methane_input = replace(methane) - perturbed_methane_input.positions = torch.matmul(methane.positions, rotation_matrix) + methane.positions = torch.matmul(methane.positions, rotation_matrix) - reference_out = model(methane) - perturbed_out = model(perturbed_methane_input) + reference_out = sake(reference_methane) + perturbed_out = sake(methane) assert torch.allclose( - reference_out["per_molecule_energy"], perturbed_out["per_molecule_energy"] + reference_out["per_system_energy"], perturbed_out["per_system_energy"] ) diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index ded1c8f2..46ab67c4 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -4,66 +4,47 @@ load_precalculated_schnet_results, setup_single_methane_input, ) +from typing import Optional -def initialize_model( - cutoff: float, - number_of_atom_features: int, - number_of_radial_basis_functions: int, - nr_of_interactions: int, -): - # ------------------------------------ # - # set up the modelforge Painn representation model - # which means that we only want to call the - # _transform_input() method - from modelforge.potential.schnet import SchNet +def setup_schnet_model(potential_seed: Optional[int] = None): + from modelforge.tests.test_potentials import load_configs_into_pydantic_models + from modelforge.potential import NeuralNetworkPotentialFactory - return SchNet( - max_Z=101, - number_of_atom_features=number_of_atom_features, - number_of_interaction_modules=nr_of_interactions, - number_of_radial_basis_functions=number_of_radial_basis_functions, - cutoff=cutoff, - number_of_filters=number_of_atom_features, - shared_interactions=False, - processing_operation=[], - readout_operation=[ - { - "step": "from_atom_to_molecule", - "mode": "sum", - "in": "per_atom_energy", - "index_key": "atomic_subsystem_indices", - "out": "E", - } - ], - ) + # read default parameters + config = load_configs_into_pydantic_models("schnet", "qm9") + # override defaults to match reference implementation in spk + config[ + "potential" + ].core_parameter.featurization.atomic_number.number_of_per_atom_features = 12 + config["potential"].core_parameter.number_of_radial_basis_functions = 5 + config["potential"].core_parameter.number_of_filters = 12 + + model = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=config["potential"], + training_parameter=config["training"], + dataset_parameter=config["dataset"], + runtime_parameter=config["runtime"], + potential_seed=potential_seed, + ).lightning_module.potential + return model def test_init(): """Test initialization of the Schnet model.""" - from modelforge.potential.schnet import SchNet - - from modelforge.tests.test_models import load_configs - - # load default parameters - config = load_configs(f"schnet", "qm9") - # initialize model - schnet = SchNet( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ) + schnet = setup_schnet_model() assert schnet is not None, "Schnet model should be initialized." -def test_compare_representation(): +def test_compare_rbf(): # compare schnetpack RadialSymmetryFunction with modelforge RadialSymmetryFunction - from modelforge.potential.utils import SchnetRadialBasisFunction + from modelforge.potential import SchnetRadialBasisFunction from openff.units import unit # Initialize the RBFs number_of_gaussians = 10 - cutoff = unit.Quantity(5.2, unit.angstrom) - start = unit.Quantity(0.8, unit.angstrom) + cutoff = unit.Quantity(5.2, unit.angstrom).to(unit.nanometer).m + start = unit.Quantity(0.8, unit.angstrom).to(unit.nanometer).m rbf_module = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_gaussians, @@ -157,32 +138,11 @@ def test_compare_representation(): ) # NOTE: there is a shape mismatch between the two outputs -def test_compare_forward(): +def test_compare_implementation_against_reference_implementation(): # ---------------------------------------- # # test the implementation of the representation part of the PaiNN model # ---------------------------------------- # - from modelforge.potential.schnet import SchNet - - from modelforge.tests.test_models import load_configs - - # load default parameters - config = load_configs(f"schnet", "qm9") - - # override default parameters - config["potential"]["core_parameter"]["number_of_atom_features"] = 12 - config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5 - config["potential"]["core_parameter"]["number_of_filters"] = 12 - - print(f"{config['potential']['core_parameter']}=") - - torch.manual_seed(1234) - - # initialize model - schnet = SchNet( - **config["potential"]["core_parameter"], - postprocessing_parameter=config["potential"]["postprocessing_parameter"], - ).double() - + model = setup_schnet_model(1234).double() # ------------------------------------ # # reference values # generated with schnetpack2.0 @@ -193,113 +153,28 @@ def test_compare_forward(): spk_input = input["spk_methane_input"] model_input = input["modelforge_methane_input"] - schnet.input_preparation._input_checks(model_input) - - pairlist_output = schnet.input_preparation.prepare_inputs(model_input) - prepared_input = schnet.core_module._model_specific_input_preparation( - model_input, pairlist_output - ) - - # ---------------------------------------- # - # test neighborlist and distance - # ---------------------------------------- # - assert torch.allclose(spk_input["_Rij"] / 10, prepared_input.r_ij, atol=1e-4) - assert torch.allclose(spk_input["_idx_i"], prepared_input.pair_indices[0]) - assert torch.allclose(spk_input["_idx_j"], prepared_input.pair_indices[1]) - - # ---------------------------------------- # - # test radial symmetry function - # ---------------------------------------- # - r_ij = spk_input["_Rij"] - d_ij = torch.norm(r_ij, dim=1, keepdim=True) - - reference_phi_ij = torch.tensor( - [ - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9130, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9130, 0.8484, 0.2900, 0.0365], - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.3615, 0.9130, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - [0.6828, 0.9920, 0.5302, 0.1043, 0.0075], - [0.3615, 0.9130, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - [0.3615, 0.9131, 0.8484, 0.2900, 0.0365], - ], - dtype=torch.float64, - ) - calculated_phi_ij = ( - schnet.core_module.schnet_representation_module.radial_symmetry_function_module( - d_ij / 10 - ) - ) # NOTE: converting to nm - - assert torch.allclose(reference_phi_ij.squeeze(1), calculated_phi_ij, atol=1e-3) - # ---------------------------------------- # - # test cutoff - # ---------------------------------------- # - reference_fcut = torch.tensor( - [ - [0.8869], - [0.8869], - [0.8869], - [0.8869], - [0.8869], - [0.7177], - [0.7177], - [0.7177], - [0.8869], - [0.7177], - [0.7177], - [0.7177], - [0.8869], - [0.7177], - [0.7177], - [0.7177], - [0.8869], - [0.7177], - [0.7177], - [0.7177], - ], - dtype=torch.float64, - ) - calculated_fcut = schnet.core_module.schnet_representation_module.cutoff_module( - d_ij / 10 - ) # NOTE: converting to nm - assert torch.allclose(reference_fcut, calculated_fcut, atol=1e-4) - # ---------------------------------------- # # test forward pass # ---------------------------------------- # # reset torch.manual_seed(1234) for i in range(3): - schnet.core_module.interaction_modules[i].intput_to_feature.reset_parameters() + model.core_network.interaction_modules[i].intput_to_feature.reset_parameters() for j in range(2): - schnet.core_module.interaction_modules[i].feature_to_output[ + model.core_network.interaction_modules[i].feature_to_output[ j ].reset_parameters() - schnet.core_module.interaction_modules[i].filter_network[ + model.core_network.interaction_modules[i].filter_network[ j ].reset_parameters() - calculated_results = schnet.core_module.forward(model_input, pairlist_output) + calculated_results = model.compute_core_network_output(model_input) reference_results = load_precalculated_schnet_results() assert ( reference_results["scalar_representation"].shape - == calculated_results["scalar_representation"].shape + == calculated_results["per_atom_scalar_representation"].shape ) scalar_spk = reference_results["scalar_representation"] - scalar_mf = calculated_results["scalar_representation"] + scalar_mf = calculated_results["per_atom_scalar_representation"] assert torch.allclose(scalar_spk, scalar_mf, atol=1e-4) diff --git a/modelforge/tests/test_spk.py b/modelforge/tests/test_spk.py index 619c3d67..2282acd1 100644 --- a/modelforge/tests/test_spk.py +++ b/modelforge/tests/test_spk.py @@ -1,5 +1,6 @@ import torch import pytest +from openff.units import unit @pytest.mark.xfail @@ -41,11 +42,15 @@ def test_compare_radial_symmetry_features(): def setup_spk_painn_representation( - cutoff, nr_atom_basis, number_of_gaussians, nr_of_interactions + cutoff, + nr_atom_basis, + number_of_gaussians, + nr_of_interactions, + maximum_atomic_number, ): # ------------------------------------ # # set up the schnetpack Painn representation model - from schnetpack.nn import GaussianRBF, CosineCutoff + from schnetpack.nn import GaussianRBF, CosineAttenuationFunction from schnetpack.representation import PaiNN as schnetpack_PaiNN from openff.units import unit @@ -56,12 +61,17 @@ def setup_spk_painn_representation( n_atom_basis=nr_atom_basis, n_interactions=nr_of_interactions, radial_basis=radial_basis, - cutoff_fn=CosineCutoff(cutoff.to(unit.angstrom).m), + cutoff_fn=CosineAttenuationFunction(cutoff.to(unit.angstrom).m), + max_z=maximum_atomic_number, ) def setup_modelforge_painn_representation( - cutoff, nr_atom_basis, number_of_gaussians, nr_of_interactions + cutoff, + nr_atom_basis, + number_of_gaussians, + nr_of_interactions, + maximum_atomic_number, ): # ------------------------------------ # # set up the modelforge Painn representation model @@ -71,23 +81,28 @@ def setup_modelforge_painn_representation( from openff.units import unit return mf_PaiNN( - max_Z=100, - number_of_atom_features=nr_atom_basis, + featurization={ + "properties_to_featurize": ["atomic_number"], + "maximum_atomic_number": maximum_atomic_number, + "number_of_per_atom_features": nr_atom_basis, + }, number_of_interaction_modules=nr_of_interactions, number_of_radial_basis_functions=number_of_gaussians, - cutoff=cutoff, + maximum_interaction_radius=cutoff, shared_interactions=False, shared_filters=False, - processing_operation=[], - readout_operation=[ - { - "step": "from_atom_to_molecule", - "mode": "sum", - "in": 'per_atom_energy', - "index_key": "atomic_subsystem_indices", - "out": "E", - } - ], + activation_function="SiLU", + postprocessing_parameter={ + "per_atom_energy": { + "normalize": True, + "from_atom_to_system_reduction": True, + "keep_per_atom_property": True, + }, + "general_postprocessing_operation": { + "calculate_molecular_self_energy": True, + "calculate_atomic_self_energy": False, + }, + }, ) @@ -102,14 +117,23 @@ def test_painn_representation_implementation(): nr_atom_basis = 128 number_of_gaussians = 5 nr_of_interactions = 3 + maximum_atomic_number = 23 torch.manual_seed(1234) schnetpack_painn = setup_spk_painn_representation( - cutoff, nr_atom_basis, number_of_gaussians, nr_of_interactions + cutoff, + nr_atom_basis, + number_of_gaussians, + nr_of_interactions, + maximum_atomic_number, ).double() torch.manual_seed(1234) modelforge_painn = setup_modelforge_painn_representation( - cutoff, nr_atom_basis, number_of_gaussians, nr_of_interactions + cutoff, + nr_atom_basis, + number_of_gaussians, + nr_of_interactions, + maximum_atomic_number, ).double() # ------------------------------------ # # set up the input for the spk Painn model @@ -118,8 +142,8 @@ def test_painn_representation_implementation(): mf_nnp_input = input["modelforge_methane_input"] schnetpack_results = schnetpack_painn(spk_input) - modelforge_painn.input_preparation._input_checks(mf_nnp_input) - pairlist_output = modelforge_painn.input_preparation.prepare_inputs(mf_nnp_input) + modelforge_painn.compute_interacting_pairs._input_checks(mf_nnp_input) + pairlist_output = modelforge_painn.compute_interacting_pairs.forward(mf_nnp_input) pain_nn_input_mf = modelforge_painn.core_module._model_specific_input_preparation( mf_nnp_input, pairlist_output ) @@ -173,9 +197,7 @@ def test_painn_representation_implementation(): spk_input[properties.Z].to(torch.int32), mf_nnp_input.atomic_numbers.squeeze() ) embedding_spk = schnetpack_painn.embedding(spk_input[properties.Z]) - embedding_mf = modelforge_painn.core_module.embedding_module( - mf_nnp_input.atomic_numbers - ) + embedding_mf = modelforge_painn.core_module.featurize_input(mf_nnp_input) assert torch.allclose(embedding_spk, embedding_mf) # ---------------------------------------- # @@ -190,7 +212,7 @@ def test_painn_representation_implementation(): assert torch.allclose(q_spk_initial, q_mf_initial) mu_spk_initial = torch.zeros((spk_qs[0], 3, spk_qs[2])) - mu_mf_initial = torch.zeros((mf_qs[0], 3, mf_qs[2])) + mu_mf_initial = torch.zeros((mf_qs[0], 3, mf_qs[2]), dtype=torch.float64) assert mu_spk_initial.shape == mu_mf_initial.shape # set up the filter and interaction, pass the input and compare the results @@ -233,7 +255,9 @@ def test_painn_representation_implementation(): torch.manual_seed(1234) pair_indices = pain_nn_input_mf.pair_indices filter_list = torch.split( - filters_mf, 3 * modelforge_painn.core_module.number_of_atom_features, dim=-1 + filters_mf, + 3 * modelforge_painn.core_module.representation_module.nr_atom_basis, + dim=-1, ) # test intra-atomic NNP @@ -327,7 +351,9 @@ def test_painn_representation_implementation(): q_spk, mu_spk = mixing(q_spk, mu_spk) mf_filter_list = torch.split( - filters_mf, 3 * modelforge_painn.core_module.number_of_atom_features, dim=-1 + filters_mf, + 3 * modelforge_painn.core_module.representation_module.nr_atom_basis, + dim=-1, ) # q_mf = q_mf_initial # mu_mf = mu_mf_initial @@ -353,7 +379,9 @@ def test_painn_representation_implementation(): filters_spk, 3 * schnetpack_painn.n_atom_basis, dim=-1 ) mf_filter_list = torch.split( - filters_mf, 3 * modelforge_painn.core_module.number_of_atom_features, dim=-1 + filters_mf, + 3 * modelforge_painn.core_module.representation_module.nr_atom_basis, + dim=-1, ) # q_spk = q_spk_initial @@ -398,7 +426,9 @@ def test_painn_representation_implementation(): schnetpack_painn.filter_net.weight, atol=1e-4, ) - modelforge_results = modelforge_painn.core_module.compute_properties(pain_nn_input_mf) + modelforge_results = modelforge_painn.core_module.compute_properties( + pain_nn_input_mf + ) schnetpack_results = schnetpack_painn(spk_input) assert ( @@ -418,11 +448,15 @@ def test_painn_representation_implementation(): def setup_spk_schnet_representation( - cutoff: float, number_of_atom_features: int, n_rbf: int, nr_of_interactions: int + cutoff: unit.Quantity, + number_of_atom_features: int, + n_rbf: int, + nr_of_interactions: int, + maximum_atomic_number: int, ): # ------------------------------------ # # set up the schnetpack Painn representation model - from schnetpack.nn import GaussianRBF, CosineCutoff + from schnetpack.nn import GaussianRBF, CosineAttenuationFunction from schnetpack.representation import SchNet as schnetpack_SchNET from openff.units import unit @@ -431,16 +465,18 @@ def setup_spk_schnet_representation( n_atom_basis=number_of_atom_features, n_interactions=nr_of_interactions, radial_basis=radial_basis, - cutoff_fn=CosineCutoff(cutoff.to(unit.angstrom).m), + max_z=maximum_atomic_number, + cutoff_fn=CosineAttenuationFunction(cutoff.to(unit.angstrom).m), ) @pytest.mark.xfail def setup_mf_schnet_representation( - cutoff: float, + cutoff: unit.Quantity, number_of_atom_features: int, number_of_radial_basis_functions: int, nr_of_interactions: int, + maximum_atomic_number: int, ): # ------------------------------------ # # set up the modelforge Painn representation model @@ -449,23 +485,28 @@ def setup_mf_schnet_representation( from modelforge.potential.schnet import SchNet as mf_SchNET return mf_SchNET( - max_Z=101, - number_of_atom_features=number_of_atom_features, + featurization={ + "properties_to_featurize": ["atomic_number"], + "maximum_atomic_number": maximum_atomic_number, + "number_of_per_atom_features": number_of_atom_features, + }, number_of_interaction_modules=nr_of_interactions, number_of_radial_basis_functions=number_of_radial_basis_functions, - cutoff=cutoff, + maximum_interaction_radius=cutoff, number_of_filters=number_of_atom_features, shared_interactions=False, - processing_operation=[], - readout_operation=[ - { - "step": "from_atom_to_molecule", - "mode": "sum", - "in": 'per_atom_energy', - "index_key": "atomic_subsystem_indices", - "out": "E", - } - ], + activation_function="ShiftedSoftplus", + postprocessing_parameter={ + "per_atom_energy": { + "normalize": True, + "from_atom_to_system_reduction": True, + "keep_per_atom_property": True, + }, + "general_postprocessing_operation": { + "calculate_molecular_self_energy": True, + "calculate_atomic_self_energy": False, + }, + }, ) @@ -480,13 +521,22 @@ def test_schnet_representation_implementation(): number_of_atom_features = 12 n_rbf = 5 nr_of_interactions = 3 + maximum_atomic_number = 23 torch.manual_seed(1234) schnetpack_schnet = setup_spk_schnet_representation( - cutoff, number_of_atom_features, n_rbf, nr_of_interactions + cutoff, + number_of_atom_features, + n_rbf, + nr_of_interactions, + maximum_atomic_number, ).double() torch.manual_seed(1234) modelforge_schnet = setup_mf_schnet_representation( - cutoff, number_of_atom_features, n_rbf, nr_of_interactions + cutoff, + number_of_atom_features, + n_rbf, + nr_of_interactions, + maximum_atomic_number, ).double() # ------------------------------------ # # set up the input for the spk Schnet model @@ -494,9 +544,11 @@ def test_schnet_representation_implementation(): spk_input = input["spk_methane_input"] mf_nnp_input = input["modelforge_methane_input"] - modelforge_schnet.input_preparation._input_checks(mf_nnp_input) + modelforge_schnet.compute_interacting_pairs._input_checks(mf_nnp_input) - pairlist_output = modelforge_schnet.input_preparation.prepare_inputs(mf_nnp_input) + pairlist_output = modelforge_schnet.compute_interacting_pairs.prepare_inputs( + mf_nnp_input + ) schnet_nn_input_mf = ( modelforge_schnet.core_module._model_specific_input_preparation( mf_nnp_input, pairlist_output @@ -541,9 +593,7 @@ def test_schnet_representation_implementation(): schnet_nn_input_mf.atomic_numbers.squeeze(), ) embedding_spk = schnetpack_schnet.embedding(spk_input[properties.Z]) - embedding_mf = modelforge_schnet.core_module.embedding_module( - schnet_nn_input_mf.atomic_numbers - ) + embedding_mf = modelforge_schnet.core_module.featurize_input(schnet_nn_input_mf) assert torch.allclose(embedding_spk, embedding_mf) @@ -631,7 +681,9 @@ def test_schnet_representation_implementation(): assert torch.allclose(v_spk, v_mf) # Check full pass - modelforge_results = modelforge_schnet.core_module.compute_properties(schnet_nn_input_mf) + modelforge_results = modelforge_schnet.core_module.compute_properties( + schnet_nn_input_mf + ) schnetpack_results = schnetpack_schnet(spk_input) assert ( diff --git a/modelforge/tests/test_tensornet.py b/modelforge/tests/test_tensornet.py index b569b495..0df74dfe 100644 --- a/modelforge/tests/test_tensornet.py +++ b/modelforge/tests/test_tensornet.py @@ -1,9 +1,352 @@ -def test_tensornet_init(): +import pytest - # This test is a placeholder - return None +from modelforge.tests.helper_functions import setup_potential_for_test - from modelforge.potential.tensornet import TensorNet - net = TensorNet() - assert net is not None +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_tensornet") + return fn + + +def test_init(prep_temp_dir): + """Test initialization of the TensorNet model.""" + + # load default parameters + # read default parameters + model = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name="tensornet", + simulation_environment="JAX", + local_cache_dir=str(prep_temp_dir), + ) + assert model is not None, "TensorNet model should be initialized." + + +@pytest.mark.parametrize("simulation_environment", ["PyTorch", "JAX"]) +def test_forward_with_inference_model( + simulation_environment, single_batch_with_batchsize, prep_temp_dir +): + + nnp_input = single_batch_with_batchsize( + batch_size=32, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ).nnp_input + + # load default parameters + model = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name="tensornet", + simulation_environment=simulation_environment, + use_training_mode_neighborlist=True, + local_cache_dir=str(prep_temp_dir), + ) + + if simulation_environment == "JAX": + from modelforge.jax import convert_NNPInput_to_jax + + model(convert_NNPInput_to_jax(nnp_input)) + else: + model(nnp_input) + + +def test_input(prep_temp_dir): + import torch + from loguru import logger as log + + from modelforge.tests.precalculated_values import ( + prepare_values_for_test_tensornet_input, + ) + + # setup model + model = setup_potential_for_test( + use="inference", + potential_seed=42, + potential_name="tensornet", + simulation_environment="PyTorch", + use_training_mode_neighborlist=True, + local_cache_dir=str(prep_temp_dir), + ) + + from importlib import resources + + from modelforge.tests import data + + # load reference data + reference_data = resources.files(data) / "tensornet_input.pt" + reference_batch = resources.files(data) / "nnp_input.pkl" + import pickle + + mf_input = pickle.load(open(reference_batch, "rb")) + + # calculate pairlist + pairlist_output = model.neighborlist.forward(mf_input) + + # compare to torchmd-net pairlist + if reference_data: + log.warning('Using reference data for "test_input"') + edge_index, edge_weight, edge_vec = torch.load(reference_data) + else: + log.warning('Calculating reference data from "test_input"') + edge_index, edge_weight, edge_vec = prepare_values_for_test_tensornet_input( + mf_input, + seed=0, + ) + + # reshape and compare + pair_indices = pairlist_output.pair_indices.t() + edge_index = edge_index.t() + for _, pair_index in enumerate(pair_indices): + idx = ((edge_index == pair_index).sum(axis=1) == 2).nonzero()[0][ + 0 + ] # select [True, True] + print(pairlist_output.d_ij[_][0], edge_weight[idx]) + assert torch.allclose( + pairlist_output.d_ij[_][0], + edge_weight[idx], + rtol=1e-3, + ) + print(pairlist_output.r_ij[_], -edge_vec[idx]) + assert torch.allclose( + pairlist_output.r_ij[_], + -1 * edge_vec[idx], + rtol=1e-3, + ) + + +def test_compare_radial_symmetry_features(): + # Compare the TensorNet radial symmetry function to the output of the + # modelforge radial symmetry function TODO: only 'expnorm' from TensorNet + # implemented + import torch + from openff.units import unit + + from modelforge.potential import ( + CosineAttenuationFunction, + TensorNetRadialBasisFunction, + ) + from modelforge.tests.precalculated_values import ( + prepare_values_for_test_tensornet_compare_radial_symmetry_features, + ) + + seed = 0 + torch.manual_seed(seed) + from importlib import resources + + from modelforge.tests import data + + reference_data = resources.files(data) / "tensornet_radial_symmetry_features.pt" + + # generate a random list of distances, all < 5 + d_ij = unit.Quantity( + torch.tensor([[2.4813], [3.8411], [0.4424], [0.6602], [1.5371]]), unit.angstrom + ) + + # TensorNet constants + maximum_interaction_radius = unit.Quantity(5.1, unit.angstrom) + minimum_interaction_radius = unit.Quantity(0.0, unit.angstrom) + number_of_per_atom_features = 8 + alpha = ( + (maximum_interaction_radius - minimum_interaction_radius) + / unit.Quantity(5.0, unit.angstrom) + ).m / 10 + + rsf = TensorNetRadialBasisFunction( + number_of_radial_basis_functions=number_of_per_atom_features, + max_distance=maximum_interaction_radius.to(unit.nanometer).m, + min_distance=minimum_interaction_radius.to(unit.nanometer).m, + alpha=alpha, + ) + mf_r = rsf(d_ij.to(unit.nanometer).m) # torch.Size([5, 8]) + cutoff_module = CosineAttenuationFunction( + maximum_interaction_radius.to(unit.nanometer).m + ) + + rcut_ij = cutoff_module(d_ij.to(unit.nanometer).m) # torch.Size([5, 1]) + mf_r = (mf_r * rcut_ij).unsqueeze(1) + + from importlib import resources + + from modelforge.tests import data + + reference_data = resources.files(data) / "tensornet_radial_symmetry_features.pt" + + if reference_data: + tn_r = torch.load(reference_data) + else: + tn_r = prepare_values_for_test_tensornet_compare_radial_symmetry_features( + d_ij, + minimum_interaction_radius, + maximum_interaction_radius, + number_of_per_atom_features, + trainable=False, + seed=seed, + ) + + assert torch.allclose(mf_r, tn_r, atol=1e-4) + + +def test_representation(prep_temp_dir): + from importlib import resources + + import torch + from openff.units import unit + from torch import nn + + from modelforge.potential.tensornet import TensorNetRepresentation + from modelforge.tests import data + + reference_data = resources.files(data) / "tensornet_representation.pt" + + number_of_per_atom_features = 8 + num_rbf = 16 + act_class = nn.SiLU + cutoff_lower = 0.0 + cutoff_upper = 5.1 + trainable_rbf = False + highest_atomic_number = 128 + + import pickle + + reference_batch = resources.files(data) / "nnp_input.pkl" + nnp_input = pickle.load(open(reference_batch, "rb")) + # -------------------------------# + # -------------------------------# + # Test that we can add the reference energy correctly + # get methane input + model = setup_potential_for_test( + use="inference", + potential_seed=0, + potential_name="tensornet", + simulation_environment="PyTorch", + use_training_mode_neighborlist=True, + local_cache_dir=str(prep_temp_dir), + ) + pairlist_output = model.neighborlist.forward(nnp_input) + + ################ modelforge TensorNet ################ + torch.manual_seed(0) + tensornet_representation_module = TensorNetRepresentation( + number_of_per_atom_features, + num_rbf, + act_class(), + unit.Quantity(cutoff_upper, unit.angstrom).to(unit.nanometer).m, + unit.Quantity(cutoff_lower, unit.angstrom).to(unit.nanometer).m, + trainable_rbf, + highest_atomic_number, + ) + mf_X, _ = tensornet_representation_module(nnp_input, pairlist_output) + ################ modelforge TensorNet ################ + + ################ torchmd-net TensorNet ################ + if reference_data: + tn_X = torch.load(reference_data) + else: + tn_X = prepare_values_for_test_tensornet_representation( + nnp_input, + number_of_per_atom_features, + num_rbf, + act_class, + cutoff_lower, + cutoff_upper, + trainable_rbf, + highest_atomic_number, + seed=0, + ) + ################ torchmd-net TensorNet ################ + + assert mf_X.shape == tn_X.shape + assert torch.allclose(mf_X, tn_X) + + +def test_interaction(prep_temp_dir): + import pickle + from importlib import resources + + import torch + from openff.units import unit + from torch import nn + + from modelforge.potential.tensornet import TensorNetInteraction + from modelforge.tests import data + from modelforge.tests.precalculated_values import ( + prepare_values_for_test_tensornet_interaction, + ) + + seed = 0 + + reference_data = resources.files(data) / "tensornet_interaction.pt" + + reference_batch = resources.files(data) / "nnp_input.pkl" + nnp_input = pickle.load(open(reference_batch, "rb")) + + number_of_per_atom_features = 8 + num_rbf = 16 + act_class = nn.SiLU + cutoff_lower = 0.0 + cutoff_upper = 5.1 + + # -------------------------------# + # -------------------------------# + # initialize model + model = setup_potential_for_test( + use="inference", + potential_seed=seed, + potential_name="tensornet", + simulation_environment="PyTorch", + use_training_mode_neighborlist=True, + local_cache_dir=str(prep_temp_dir), + ) + + ################ modelforge TensorNet ################ + tensornet_representation_module = model.core_network.representation_module + pairlist_output = model.neighborlist.forward(nnp_input) + X, _ = tensornet_representation_module(nnp_input, pairlist_output) + + radial_feature_vector = tensornet_representation_module.radial_symmetry_function( + pairlist_output.d_ij + ) + rcut_ij = tensornet_representation_module.cutoff_module(pairlist_output.d_ij) + radial_feature_vector = (radial_feature_vector * rcut_ij).unsqueeze(1) + + atomic_charges = torch.zeros_like(nnp_input.atomic_numbers) + + # interaction + torch.manual_seed(seed) + interaction_module = TensorNetInteraction( + number_of_per_atom_features, + num_rbf, + act_class(), + unit.Quantity(cutoff_upper, unit.angstrom).to(unit.nanometer).m, + "O(3)", + ) + mf_X = interaction_module( + X, + pairlist_output.pair_indices, + pairlist_output.d_ij.squeeze(-1), + radial_feature_vector.squeeze(1), + atomic_charges, + ) + ################ modelforge TensorNet ################ + + ################ TensorNet ################ + if reference_data: + tn_X = torch.load(reference_data) + else: + tn_X = prepare_values_for_test_tensornet_interaction( + X, + nnp_input, + radial_feature_vector, + atomic_charges, + number_of_per_atom_features, + num_rbf, + act_class, + cutoff_lower, + cutoff_upper, + seed, + ) + ################ TensorNet ################ + + assert mf_X.shape == tn_X.shape + assert torch.allclose(mf_X, tn_X) diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py index c9b11b26..67f0da6c 100644 --- a/modelforge/tests/test_training.py +++ b/modelforge/tests/test_training.py @@ -1,96 +1,385 @@ import os +import platform + import pytest +import torch -import platform +from modelforge.potential import NeuralNetworkPotentialFactory, _Implemented_NNPs ON_MACOS = platform.system() == "Darwin" IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" -from modelforge.potential import _Implemented_NNPs -from modelforge.potential import NeuralNetworkPotentialFactory -def load_configs(model_name: str, dataset_name: str): +@pytest.fixture(scope="session") +def prep_temp_dir(tmp_path_factory): + fn = tmp_path_factory.mktemp("test_training_temp") + return fn + + +def load_configs_into_pydantic_models( + potential_name: str, dataset_name: str, local_cache_dir: str +): + from importlib import resources + + import toml + from modelforge.tests.data import ( - potential_defaults, - training_defaults, dataset_defaults, + potential_defaults, runtime_defaults, + training_defaults, ) - from importlib import resources - from modelforge.train.training import return_toml_config - potential_path = resources.files(potential_defaults) / f"{model_name.lower()}.toml" + potential_path = ( + resources.files(potential_defaults) / f"{potential_name.lower()}.toml" + ) dataset_path = resources.files(dataset_defaults) / f"{dataset_name.lower()}.toml" - training_path = resources.files(training_defaults) / "default.toml" + training_path = resources.files(training_defaults) / f"default.toml" runtime_path = resources.files(runtime_defaults) / "runtime.toml" - return return_toml_config( - potential_path=potential_path, - dataset_path=dataset_path, - training_path=training_path, - runtime_path=runtime_path, + + training_config_dict = toml.load(training_path) + dataset_config_dict = toml.load(dataset_path) + potential_config_dict = toml.load(potential_path) + runtime_config_dict = toml.load(runtime_path) + + potential_name = potential_config_dict["potential"]["potential_name"] + + from modelforge.potential import _Implemented_NNP_Parameters + + PotentialParameters = ( + _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name) ) + potential_parameters = PotentialParameters(**potential_config_dict["potential"]) + from modelforge.dataset.dataset import DatasetParameters + from modelforge.train.parameters import RuntimeParameters, TrainingParameters -@pytest.mark.skipif(ON_MACOS, reason="Skipping this test on MacOS GitHub Actions") -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -@pytest.mark.parametrize("dataset_name", ["QM9"]) -def test_train_with_lightning(model_name, dataset_name): + dataset_parameters = DatasetParameters(**dataset_config_dict["dataset"]) + training_parameters = TrainingParameters(**training_config_dict["training"]) + runtime_parameters = RuntimeParameters(**runtime_config_dict["runtime"]) + + runtime_parameters.local_cache_dir = local_cache_dir + return { + "potential": potential_parameters, + "dataset": dataset_parameters, + "training": training_parameters, + "runtime": runtime_parameters, + } + + +def get_trainer(config): + # Extract parameters + potential_parameter = config["potential"] + training_parameter = config["training"] + dataset_parameter = config["dataset"] + runtime_parameter = config["runtime"] + + return NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=potential_parameter, + training_parameter=training_parameter, + dataset_parameter=dataset_parameter, + runtime_parameter=runtime_parameter, + ) + + +def add_force_to_loss_parameter(config): """ - Test the forward pass for a given model and dataset. + [training.loss_parameter] + loss_components = ['per_system_energy', 'per_atom_force'] + # ------------------------------------------------------------ # + [training.loss_parameter.weight] + per_system_energy = 0.999 #NOTE: reciprocal units + per_atom_force = 0.001 + """ + t_config = config["training"] + t_config.loss_parameter.loss_components.append("per_atom_force") + t_config.loss_parameter.weight["per_atom_force"] = 0.001 + t_config.loss_parameter.target_weight["per_atom_force"] = 0.001 + t_config.loss_parameter.mixing_steps["per_atom_force"] = 1 - from modelforge.train.training import perform_training - # read default parameters - config = load_configs(model_name, dataset_name) +def add_dipole_moment_to_loss_parameter(config): + """ + [training.loss_parameter] + loss_components = [ + "per_system_energy", + "per_atom_force", + "per_system_dipole_moment", + "per_system_total_charge", + ] + [training.loss_parameter.weight] + per_system_energy = 1 #NOTE: reciprocal units + per_atom_force = 0.1 + per_system_dipole_moment = 0.01 + per_system_total_charge = 0.01 - # Extract parameters - potential_config = config["potential"] - training_config = config["training"] - dataset_config = config["dataset"] - runtime_config = config["runtime"] + """ + t_config = config["training"] + t_config.loss_parameter.loss_components.append("per_system_dipole_moment") + t_config.loss_parameter.loss_components.append("per_system_total_charge") + t_config.loss_parameter.weight["per_system_dipole_moment"] = 0.01 + t_config.loss_parameter.weight["per_system_total_charge"] = 0.01 + + t_config.loss_parameter.target_weight["per_system_dipole_moment"] = 0.01 + t_config.loss_parameter.target_weight["per_system_total_charge"] = 0.01 + + t_config.loss_parameter.mixing_steps["per_system_dipole_moment"] = 1 + t_config.loss_parameter.mixing_steps["per_system_total_charge"] = 1 + + # also add per_atom_charge to predicted properties + + p_config = config["potential"] + p_config.core_parameter.predicted_properties.append("per_atom_charge") + p_config.core_parameter.predicted_dim.append(1) + + +def replace_per_system_with_per_atom_loss(config): + t_config = config["training"] + t_config.loss_parameter.loss_components.remove("per_system_energy") + t_config.loss_parameter.loss_components.append("per_atom_energy") + + t_config.loss_parameter.weight.pop("per_system_energy") + t_config.loss_parameter.weight["per_atom_energy"] = 0.999 + + t_config.loss_parameter.target_weight.pop("per_system_energy") + t_config.loss_parameter.target_weight["per_atom_energy"] = 0.999 + + t_config.loss_parameter.mixing_steps.pop("per_system_energy") + t_config.loss_parameter.mixing_steps["per_atom_energy"] = 1 - # perform training - trainer = perform_training( - potential_config=potential_config, - training_config=training_config, - dataset_config=dataset_config, - runtime_config=runtime_config, + # NOTE: the loss is calculate per_atom, but the validation set error is + # per_system. This is because otherwise it's difficult to compare. + t_config.early_stopping.monitor = "val/per_system_energy/rmse" + t_config.monitor = "val/per_system_energy/rmse" + t_config.lr_scheduler.monitor = "val/per_system_energy/rmse" + + +from typing import Literal + + +from typing import Literal +from modelforge.train.parameters import ( + TrainingParameters, + ReduceLROnPlateauConfig, + CosineAnnealingLRConfig, + CosineAnnealingWarmRestartsConfig, + CyclicLRConfig, + OneCycleLRConfig, +) + +from typing import Literal + + +def use_different_LRScheduler( + training_config: TrainingParameters, + which_one: Literal[ + "CosineAnnealingLR", + "ReduceLROnPlateau", + "CosineAnnealingWarmRestarts", + "OneCycleLR", + "CyclicLR", + ], +) -> TrainingParameters: + """ + Modifies the training configuration to use a different learning rate scheduler. + """ + + if which_one == "ReduceLROnPlateau": + lr_scheduler_config = ReduceLROnPlateauConfig( + scheduler_name="ReduceLROnPlateau", + frequency=1, + interval="epoch", + monitor=training_config.monitor, + mode="min", + factor=0.1, + patience=10, + threshold=0.1, + threshold_mode="abs", + cooldown=5, + min_lr=1e-8, + eps=1e-8, + ) + elif which_one == "CosineAnnealingLR": + lr_scheduler_config = CosineAnnealingLRConfig( + scheduler_name="CosineAnnealingLR", + frequency=1, + interval="epoch", + monitor=training_config.monitor, + T_max=training_config.number_of_epochs, + eta_min=0.0, + last_epoch=-1, + ) + elif which_one == "CosineAnnealingWarmRestarts": + lr_scheduler_config = CosineAnnealingWarmRestartsConfig( + scheduler_name="CosineAnnealingWarmRestarts", + frequency=1, + interval="epoch", + monitor=training_config.monitor, + T_0=10, + T_mult=2, + eta_min=0.0, + last_epoch=-1, + ) + elif which_one == "OneCycleLR": + lr_scheduler_config = OneCycleLRConfig( + scheduler_name="OneCycleLR", + frequency=1, + interval="step", + monitor=None, + max_lr=training_config.lr, + epochs=training_config.number_of_epochs, # Use epochs from training config + # steps_per_epoch will be calculated at runtime + pct_start=0.3, + anneal_strategy="cos", + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + ) + elif which_one == "CyclicLR": + lr_scheduler_config = CyclicLRConfig( + scheduler_name="CyclicLR", + frequency=1, + interval="step", + monitor=None, + base_lr=training_config.lr / 10, + max_lr=training_config.lr, + epochs_up=1.0, # For example, increasing phase lasts 1 epoch + epochs_down=1.0, # Decreasing phase lasts 1 epoch + mode="triangular", + gamma=1.0, + scale_mode="cycle", + cycle_momentum=True, + base_momentum=0.8, + max_momentum=0.9, + last_epoch=-1, + ) + else: + raise ValueError(f"Unsupported scheduler: {which_one}") + + # Update the lr_scheduler in the training configuration + training_config.lr_scheduler = lr_scheduler_config + return training_config + + +import pytest +from modelforge.train.parameters import TrainingParameters + + +@pytest.mark.parametrize("potential_name", ["ANI2x"]) +@pytest.mark.parametrize("dataset_name", ["PHALKETHOH"]) +@pytest.mark.parametrize( + "lr_scheduler", + [ + "ReduceLROnPlateau", + "CosineAnnealingLR", + "CosineAnnealingWarmRestarts", + "OneCycleLR", + "CyclicLR", + ], +) +def test_learning_rate_scheduler( + potential_name, + dataset_name, + lr_scheduler, + prep_temp_dir, +): + """ + Test that we can train, save, and load checkpoints with different learning rate schedulers. + """ + # Load the configuration into Pydantic models + config = load_configs_into_pydantic_models( + potential_name, dataset_name, str(prep_temp_dir) ) - # save checkpoint - trainer.save_checkpoint("test.chp") - # continue training - trainer = perform_training( - potential_config=potential_config, - training_config=training_config, - dataset_config=dataset_config, - runtime_config=runtime_config, - checkpoint_path="test.chp", + + # Get the training configuration + training_config = config["training"] + # Modify the training configuration to use the selected scheduler + training_config = use_different_LRScheduler(training_config, lr_scheduler) + + config["training"] = training_config + # Proceed with training + get_trainer(config).train_potential().save_checkpoint("test.chp") # save checkpoint + + +@pytest.mark.xdist_group(name="test_training_with_lightning") +@pytest.mark.skipif(ON_MACOS, reason="Skipping this test on MacOS GitHub Actions") +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +@pytest.mark.parametrize("dataset_name", ["PHALKETHOH"]) +@pytest.mark.parametrize( + "loss", + ["energy", "energy_force", "normalized_energy_force", "energy_force_dipole_moment"], +) +def test_train_with_lightning(loss, potential_name, dataset_name, prep_temp_dir): + """ + Test that we can train, save and load checkpoints. + """ + + # SKIP if potential is ANI and dataset is SPICE2 + if "ANI" in potential_name and dataset_name == "SPICE2": + pytest.skip("ANI potential is not compatible with SPICE2 dataset") + if IN_GITHUB_ACTIONS and potential_name == "SAKE" and "force" in loss: + pytest.skip( + "Skipping Sake training with forces because it allocates too much memory" + ) + + config = load_configs_into_pydantic_models( + potential_name, dataset_name, str(prep_temp_dir) ) + if "force" in loss: + add_force_to_loss_parameter(config) + if "normalized" in loss: + replace_per_system_with_per_atom_loss(config) + if "dipole_moment" in loss: + add_dipole_moment_to_loss_parameter(config) -import torch + # train potential + get_trainer(config).train_potential().save_checkpoint("test.chp") # save checkpoint + # continue training from checkpoint + # get_trainer(config).train_potential() + + +def test_train_from_single_toml_file(prep_temp_dir): + from importlib import resources + + from modelforge.tests import data + from modelforge.train.training import read_config_and_train + + config_path = resources.files(data) / f"config.toml" + + read_config_and_train(config_path, local_cache_dir=str(prep_temp_dir)) -def test_error_calculation(single_batch_with_batchsize_16_with_force): +def test_error_calculation(single_batch_with_batchsize, prep_temp_dir): # test the different Loss classes - from modelforge.train.training import ( - FromPerAtomToPerMoleculeMeanSquaredError, - PerMoleculeMeanSquaredError, + from modelforge.train.losses import ( + ForceSquaredError, + EnergySquaredError, ) # generate data - data = single_batch_with_batchsize_16_with_force - true_E = data.metadata.E - true_F = data.metadata.F + batch = single_batch_with_batchsize( + batch_size=16, dataset_name="PHALKETHOH", local_cache_dir=str(prep_temp_dir) + ) + + data = batch + true_E = data.metadata.per_system_energy + true_F = data.metadata.per_atom_force # make predictions predicted_E = true_E + torch.rand_like(true_E) * 10 predicted_F = true_F + torch.rand_like(true_F) * 10 # test error for property with shape (nr_of_molecules, 1) - error = PerMoleculeMeanSquaredError() + error = EnergySquaredError() E_error = error(predicted_E, true_E, data) # compare output (mean squared error scaled by number of atoms in the molecule) @@ -100,19 +389,18 @@ def test_error_calculation(single_batch_with_batchsize_16_with_force): 1 ) # FIXME : fi reference_E_error = torch.mean(scale_squared_error) - assert torch.allclose(E_error, reference_E_error) + assert torch.allclose(torch.mean(E_error), reference_E_error) # test error for property with shape (nr_of_atoms, 3) - error = FromPerAtomToPerMoleculeMeanSquaredError() + error = ForceSquaredError() F_error = error(predicted_F, true_F, data) # compare error (mean squared error scaled by number of atoms in the molecule) - scaled_error = ( torch.linalg.vector_norm(predicted_F - true_F, dim=1, keepdim=True) ** 2 ) - per_mol_error = torch.zeros_like(data.metadata.E) + per_mol_error = torch.zeros_like(data.metadata.per_system_energy) per_mol_error.scatter_add_( 0, data.nnp_input.atomic_subsystem_indices.unsqueeze(-1) @@ -122,52 +410,202 @@ def test_error_calculation(single_batch_with_batchsize_16_with_force): ) reference_F_error = torch.mean( - per_mol_error / data.metadata.atomic_subsystem_counts.unsqueeze(1) + per_mol_error / (3 * data.metadata.atomic_subsystem_counts.unsqueeze(1)) ) - assert torch.allclose(F_error, reference_F_error) + assert torch.allclose(torch.mean(F_error), reference_F_error) -@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Skipping this test on GitHub Actions") -@pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) -@pytest.mark.parametrize("dataset_name", ["QM9"]) -def test_hypterparameter_tuning_with_ray( - model_name, - dataset_name, - datamodule_factory, -): - from modelforge.train.training import return_toml_config, LossFactory - from importlib import resources - from modelforge.tests.data import ( - training, - potential_defaults, - dataset_defaults, - training_defaults, +def test_loss_with_dipole_moment(single_batch_with_batchsize, prep_temp_dir): + # Generate a batch with the specified batch size and dataset + batch = single_batch_with_batchsize( + batch_size=16, dataset_name="SPICE2", local_cache_dir=str(prep_temp_dir) ) - config = load_configs(model_name, dataset_name) + # Get the trainer object with the specified model and dataset + config = load_configs_into_pydantic_models( + potential_name="schnet", + dataset_name="SPICE2", + local_cache_dir=str(prep_temp_dir), + ) + add_dipole_moment_to_loss_parameter(config) + add_force_to_loss_parameter(config) - # Extract parameters - potential_config = config["potential"] - training_config = config["training"] - dataset_config = config["dataset"] - runtime_config = config["runtime"] + trainer = get_trainer( + config, + ) - dm = datamodule_factory(dataset_name=dataset_name) + # Calculate predictions using the trainer's model + prediction = trainer.lightning_module.calculate_predictions( + batch, + trainer.lightning_module.potential, + train_mode=True, # train_mode=True is required for gradients in force prediction + ) - # training model - model = NeuralNetworkPotentialFactory.generate_model( - use="training", - model_parameter=potential_config, - training_parameter=training_config["training_parameter"], + # Assertions for energy predictions + assert prediction["per_system_energy_predict"].size( + 0 + ) == batch.metadata.per_system_energy.size( + 0 + ), "Mismatch in batch size for energy predictions." + + # Assertions for force predictions + assert prediction["per_atom_force_predict"].size( + 0 + ) == batch.metadata.per_atom_force.size( + 0 + ), "Mismatch in number of atoms for force predictions." + + # Assertions for dipole moment predictions + assert ( + "per_system_dipole_moment_predict" in prediction + ), "Dipole moment prediction missing." + assert ( + prediction["per_system_dipole_moment_predict"].size() + == batch.metadata.per_system_dipole_moment.size() + ), "Mismatch in shape for dipole moment predictions." + + # Assertions for total charge predictions + assert ( + "per_system_total_charge_predict" in prediction + ), "Total charge prediction missing." + assert ( + prediction["per_system_total_charge_predict"].size() + == batch.nnp_input.per_system_total_charge.size() + ), "Mismatch in shape for total charge predictions." + + # Now compute the loss + loss_dict = trainer.lightning_module.loss( + predict_target=prediction, + batch=batch, + epoch_idx=0, ) - from modelforge.train.tuning import RayTuner + # Ensure that the loss contains the total_charge and dipole_moment terms + assert "per_system_total_charge" in loss_dict, "Total charge loss not computed." + assert "per_system_dipole_moment" in loss_dict, "Dipole moment loss not computed." + + # Check that the losses are finite numbers + assert torch.isfinite( + loss_dict["per_system_total_charge"] + ).all(), "Total charge loss contains non-finite values." + assert torch.isfinite( + loss_dict["per_system_dipole_moment"] + ).all(), "Dipole moment loss contains non-finite values." + + # Optionally, print or log the losses for debugging + print("Total Charge Loss:", loss_dict["per_system_total_charge"].mean().item()) + print("Dipole Moment Loss:", loss_dict["per_system_dipole_moment"].mean().item()) + + # Check that the total loss includes the new loss terms + assert "total_loss" in loss_dict, "Total loss not computed." + assert torch.isfinite( + loss_dict["total_loss"] + ).all(), "Total loss contains non-finite values." + + +def test_loss(single_batch_with_batchsize, prep_temp_dir): + from modelforge.train.losses import Loss + + batch = single_batch_with_batchsize( + batch_size=16, dataset_name="PHALKETHOH", local_cache_dir=str(prep_temp_dir) + ) + + loss_porperty = ["per_system_energy", "per_atom_force", "per_atom_energy"] + loss_weights = { + "per_system_energy": torch.tensor([0.5]), + "per_atom_force": torch.tensor([0.5]), + "per_atom_energy": torch.tensor([0.1]), + } + loss = Loss(loss_porperty, loss_weights) + assert loss is not None + + # Get the trainer object with the specified model and dataset + config = load_configs_into_pydantic_models( + potential_name="schnet", dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + add_force_to_loss_parameter(config) + + trainer = get_trainer( + config, + ) + prediction = trainer.lightning_module.calculate_predictions( + batch, trainer.lightning_module.potential, train_mode=True + ) # train_mode=True is required for gradients in force prediction + + assert prediction["per_system_energy_predict"].size( + dim=0 + ) == batch.metadata.per_system_energy.size(dim=0) + assert prediction["per_atom_force_predict"].size( + dim=0 + ) == batch.metadata.per_atom_force.size(dim=0) + + # pass prediction through loss module + loss_output = loss(prediction, batch, epoch_idx=0) + # let's recalculate the loss (NOTE: we scale the loss by the number of atoms) + # --------------------------------------------- # + # make sure that both have gradients + assert prediction["per_system_energy_predict"].requires_grad + assert prediction["per_atom_force_predict"].requires_grad + + # --------------------------------------------- # + # first, calculate E_loss + E_loss = torch.mean( + ( + ( + prediction["per_system_energy_predict"] + - prediction["per_system_energy_true"] + ).pow(2) + ) + ) + # compare to reference evalue obtained from Loos class + ref = torch.mean(loss_output["per_system_energy"]) + assert torch.allclose(ref, E_loss) + E_loss = torch.mean( + ( + ( + prediction["per_system_energy_predict"] + - prediction["per_system_energy_true"] + ).pow(2) + / batch.metadata.atomic_subsystem_counts.unsqueeze(1) + ) + ) + # compare to reference evalue obtained from Loos class + ref = torch.mean(loss_output["per_atom_energy"]) + assert torch.allclose(ref, E_loss) + + # --------------------------------------------- # + # now calculate F_loss + per_atom_force_squared_error = ( + (prediction["per_atom_force_predict"] - prediction["per_atom_force_true"]) + .pow(2) + .sum(dim=1, keepdim=True) + ).squeeze(-1) + + # # Aggregate error per molecule + per_system_squared_error = torch.zeros_like( + batch.metadata.per_system_energy.squeeze(-1), + dtype=per_atom_force_squared_error.dtype, + ) + per_system_squared_error.scatter_add_( + 0, + batch.nnp_input.atomic_subsystem_indices.long(), + per_atom_force_squared_error, + ) + # divide by number of atoms + per_system_squared_error = per_system_squared_error / ( + 3 * batch.metadata.atomic_subsystem_counts + ) + + per_atom_force_mse = torch.mean(per_system_squared_error) + assert torch.allclose(torch.mean(loss_output["per_atom_force"]), per_atom_force_mse) + + # --------------------------------------------- # + # let's double check that the loss is calculated correctly + # calculate the total loss - ray_tuner = RayTuner(model) - ray_tuner.tune_with_ray( - train_dataloader=dm.train_dataloader(), - val_dataloader=dm.val_dataloader(), - number_of_ray_workers=1, - number_of_epochs=1, - number_of_samples=1, + assert torch.allclose( + loss_weights["per_system_energy"] * loss_output["per_system_energy"] + + loss_weights["per_atom_force"] * loss_output["per_atom_force"] + + +loss_weights["per_atom_energy"] * loss_output["per_atom_energy"], + loss_output["total_loss"].to(torch.float32), ) diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py index d2e410fa..694df5af 100644 --- a/modelforge/tests/test_utils.py +++ b/modelforge/tests/test_utils.py @@ -1,24 +1,153 @@ import numpy as np import torch import pytest +import platform +import os +ON_MACOS = platform.system() == "Darwin" +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" @pytest.fixture(scope="session") def prep_temp_dir(tmp_path_factory): - fn = tmp_path_factory.mktemp("utils_test") + fn = tmp_path_factory.mktemp("test_utils_temp") return fn +@pytest.mark.parametrize( + "partial_point_charges, atomic_subsystem_indices, total_charge", + [ + ( + torch.zeros(6, 1), + torch.tensor([0, 0, 1, 1, 1, 1], dtype=torch.int64), + torch.tensor([0.0, 1.0]).unsqueeze(1), + ), + ( + torch.zeros(6, 1), + torch.tensor([0, 0, 1, 1, 1, 1], dtype=torch.int64), + torch.tensor([-1.0, 2.0]).unsqueeze(1), + ), + ( + torch.rand(6, 1), + torch.tensor([0, 0, 1, 1, 1, 1], dtype=torch.int64), + torch.tensor([-1.0, 2.0]).unsqueeze(1), + ), + ], +) +def test_default_charge_conservation( + partial_point_charges: torch.Tensor, + atomic_subsystem_indices: torch.Tensor, + total_charge: torch.Tensor, +): + """ + Test the default_charge_conservation function with various test cases. + """ + from modelforge.potential.processing import default_charge_conservation + + # test charge equilibration + # ------------------------- # + charges = default_charge_conservation( + partial_point_charges, + total_charge, + atomic_subsystem_indices, + ) + + # Calculate the total charge per molecule after correction + predicted_total_charge = torch.zeros_like(total_charge).scatter_add_( + 0, atomic_subsystem_indices.unsqueeze(1), charges + ) + + # Assert that the predicted total charges match the desired total charges + assert torch.allclose(predicted_total_charge, total_charge, atol=1e-6) + + +@pytest.mark.skipif( + ON_MACOS, + reason="Skipt Test on MacOS CI runners as it relies on spawning multiple threads. ", +) +def test_method_locking(tmp_path): + """ + Test the lock_with_attribute decorator to ensure that it correctly serializes access + to a critical section across multiple processes. + """ + import multiprocessing + from modelforge.utils.misc import lock_with_attribute + import time + + # Define a class with a method decorated by lock_with_attribute + class TestClass: + def __init__(self, lock_file): + self.method_lock = lock_file + + @lock_with_attribute("method_lock") + def critical_section(self, shared_list): + process_name = multiprocessing.current_process().name + # Record entry into the critical section + shared_list.append(f"{process_name} entered") + # Simulate work + time.sleep(1) + # Record exit from the critical section + shared_list.append(f"{process_name} exited") + + # Worker function to be executed by each process + def worker(lock_file, shared_list): + test_obj = TestClass(lock_file) + test_obj.critical_section(shared_list) + + # Create a Manager to handle shared state across processes + manager = multiprocessing.Manager() + shared_list = manager.list() + + # Path to the lock file within the pytest-provided temporary directory + lock_file_path = tmp_path / "test.lock" + + # List to hold process objects + processes = [] + + # Create and start multiple processes + processes = [] + for i in range(4): + p = multiprocessing.Process( + target=worker, + args=(str(lock_file_path), shared_list), + name=f"Process-{i+1}", + ) + p.start() + processes.append(p) + + # Wait for all processes to complete + for p in processes: + p.join() + + # Verify that only one process was in the critical section at any given time + entered = set() + for entry in shared_list: + if "entered" in entry: + process = entry.split()[0] + # Ensure no other process is in the critical section + assert ( + len(entered) == 0 + ), f"{process} entered while {entered} was in the critical section" + entered.add(process) + elif "exited" in entry: + process = entry.split()[0] + # Ensure the process that exits was the one that entered + assert process in entered, f"{process} exited without entering" + entered.remove(process) + + # Ensure all processes have exited the critical section + assert len(entered) == 0, f"Processes left in critical section: {entered}" + + def test_dense_layer(): - from modelforge.potential.utils import Dense + from modelforge.potential.utils import DenseWithCustomDist import torch # random 2x3 torch.Tensor x = torch.randn(2, 3) # create a Dense layer with 3 input features and 2 output features - dense_layer = Dense(in_features=3, out_features=2) + dense_layer = DenseWithCustomDist(in_features=3, out_features=2) out = dense_layer(x) # create a Dense layer with 3 input features and 2 output features @@ -27,7 +156,7 @@ def test_dense_layer(): [torch.nn.init.zeros_, torch.nn.init.ones_], ): # test the weight initialization and correct weight multiplication - dense_layer = Dense( + dense_layer = DenseWithCustomDist( in_features=3, out_features=2, bias=False, weight_init=weight_init_fn ) @@ -36,7 +165,7 @@ def test_dense_layer(): manuel_out = dense_layer.weight @ x.T assert torch.allclose(out, manuel_out.T) # test bias - dense_layer = Dense( + dense_layer = DenseWithCustomDist( in_features=3, out_features=2, bias=True, @@ -78,342 +207,66 @@ def test_cosine_cutoff(): """ Test the cosine cutoff implementation. """ - from modelforge.potential.utils import CosineCutoff + from modelforge.potential import CosineAttenuationFunction + # Define inputs x = torch.Tensor([1, 2, 3]) y = torch.Tensor([4, 5, 6]) from openff.units import unit - cutoff = 6 - # Calculate expected output + cutoff = 6 d_ij = torch.linalg.norm(x - y) expected_output = 0.5 * (torch.cos(d_ij * np.pi / cutoff) + 1.0) - cutoff = 0.6 * unit.nanometer # Calculate actual output - cutoff_module = CosineCutoff(cutoff) + cutoff = 0.6 + cutoff_module = CosineAttenuationFunction(cutoff) actual_output = cutoff_module(d_ij / 10) - # Check if the results are equal - # NOTE: Cutoff function doesn't care about the units as long as they are the same + # Check if the results are equal NOTE: Cutoff function doesn't care about + # the units as long as they are the same assert np.isclose(actual_output, expected_output) - # input in angstrom - cutoff = 2.0 * unit.angstrom - expected_output = torch.tensor([0.5, 0.0, 0.0]) - cosine_cutoff_module = CosineCutoff(cutoff) def test_cosine_cutoff_module(): - # Test CosineCutoff module - from modelforge.potential.utils import CosineCutoff + # Test CosineAttenuationFunction module + from modelforge.potential import CosineAttenuationFunction from openff.units import unit - # test the cutoff on this distance vector (NOTE: it is in angstrom) + # test the cutoff on this distance vector d_ij_angstrom = torch.tensor([1.0, 2.0, 3.0]).unsqueeze(1) # the expected outcome is that entry 1 and 2 become zero # and entry 0 becomes 0.5 (since the cutoff is 2.0 angstrom) # input in angstrom - cutoff = 2.0 * unit.angstrom - + cutoff = 2.0 expected_output = torch.tensor([0.5, 0.0, 0.0]).unsqueeze(1) - cosine_cutoff_module = CosineCutoff(cutoff) + cosine_cutoff_module = CosineAttenuationFunction(cutoff / 10) output = cosine_cutoff_module(d_ij_angstrom / 10) # input is in nanometer assert torch.allclose(output, expected_output, rtol=1e-3) -def test_radial_symmetry_function_implementation(): - """ - Test the Radial Symmetry function implementation. - """ - import torch +def test_PhysNetAttenuationFunction(): + from modelforge.potential.representation import PhysNetAttenuationFunction from openff.units import unit - import numpy as np - from modelforge.potential.utils import CosineCutoff, GaussianRadialBasisFunctionWithScaling - - cutoff_module = CosineCutoff(cutoff=unit.Quantity(5.0, unit.angstrom)) - - class RadialSymmetryFunctionTest(GaussianRadialBasisFunctionWithScaling): - @staticmethod - def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, - ): - centers = torch.linspace( - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype=dtype, - ) - return centers - - @staticmethod - def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype - ): - scale_factors = torch.full( - (number_of_radial_basis_functions,), - (_min_distance_in_nanometer - _max_distance_in_nanometer) - / number_of_radial_basis_functions, - ) - scale_factors = (scale_factors * -15_000) ** -0.5 - return scale_factors - - - RSF = RadialSymmetryFunctionTest( - number_of_radial_basis_functions=18, - max_distance=unit.Quantity(5.0, unit.angstrom), - ) - print(f"{RSF.radial_basis_centers=}") - print(f"{RSF.radial_scale_factor=}") - # test a single distance - d_ij = torch.tensor([[0.2]]) - radial_expension = RSF(d_ij) - - expected_output = np.array( - [ - 5.7777413e-08, - 5.4214674e-06, - 2.4740110e-04, - 5.4905377e-03, - 5.9259072e-02, - 3.1104434e-01, - 7.9399312e-01, - 9.8568588e-01, - 5.9509689e-01, - 1.7472850e-01, - 2.4949821e-02, - 1.7326004e-03, - 5.8513560e-05, - 9.6104134e-07, - 7.6763511e-09, - 2.9819147e-11, - 5.6333109e-14, - 5.1755549e-17, - ], - dtype=np.float32, - ) + import torch - assert np.allclose(radial_expension.numpy().flatten(), expected_output, rtol=1e-3) + # test the cutoff on this distance vector (NOTE: it is in angstrom) + d_ij_angstrom = torch.tensor([1.0, 2.0, 3.0]).unsqueeze(1) + # the expected outcome is that entry 1 and 2 become zero + # and entry 0 becomes 0.5 (since the cutoff is 2.0 angstrom) + # input in angstrom + cutoff = 2.0 * unit.angstrom - # test multiple distances with cutoff - d_ij = torch.tensor([[r] for r in np.linspace(0, 0.5, 10)]) - radial_expension = RSF(d_ij) * cutoff_module(d_ij) + expected_output = torch.tensor([0.5, 0.0, 0.0]).unsqueeze(1) + physnet_cutoff_module = PhysNetAttenuationFunction(cutoff.to(unit.nanometer).m) - expected_output = np.array( - [ - [ - 1.00000000e00, - 6.97370611e-01, - 2.36512753e-01, - 3.90097089e-02, - 3.12909145e-03, - 1.22064879e-04, - 2.31574554e-06, - 2.13657562e-08, - 9.58678574e-11, - 2.09196141e-13, - 2.22005077e-16, - 1.14577532e-19, - 2.87583090e-23, - 3.51038337e-27, - 2.08388175e-31, - 6.01615362e-36, - 8.44679753e-41, - 5.76756600e-46, - ], - [ - 2.68038176e-01, - 7.29490887e-01, - 9.65540222e-01, - 6.21510012e-01, - 1.94559846e-01, - 2.96200218e-02, - 2.19303227e-03, - 7.89645189e-05, - 1.38275834e-06, - 1.17757010e-08, - 4.87703136e-11, - 9.82316969e-14, - 9.62221521e-17, - 4.58380155e-20, - 1.06194951e-23, - 1.19649050e-27, - 6.55604552e-32, - 1.74703654e-36, - ], - [ - 5.15165267e-03, - 5.47178933e-02, - 2.82643788e-01, - 7.10030194e-01, - 8.67443988e-01, - 5.15386799e-01, - 1.48919812e-01, - 2.09266151e-02, - 1.43012111e-03, - 4.75305832e-05, - 7.68248035e-07, - 6.03888967e-09, - 2.30855409e-11, - 4.29190731e-14, - 3.88050222e-17, - 1.70629005e-20, - 3.64875837e-24, - 3.79458837e-28, - ], - [ - 7.05512776e-06, - 2.92447055e-04, - 5.89544925e-03, - 5.77981439e-02, - 2.75573882e-01, - 6.38983424e-01, - 7.20556963e-01, - 3.95161266e-01, - 1.05392022e-01, - 1.36699807e-02, - 8.62294776e-04, - 2.64527563e-05, - 3.94651201e-07, - 2.86340809e-09, - 1.01036987e-11, - 1.73382336e-14, - 1.44696036e-17, - 5.87267193e-21, - ], - [ - 6.79841545e-10, - 1.09978970e-07, - 8.65244557e-06, - 3.31051436e-04, - 6.15997825e-03, - 5.57430086e-02, - 2.45317579e-01, - 5.25042257e-01, - 5.46496226e-01, - 2.76635027e-01, - 6.81011682e-02, - 8.15322217e-03, - 4.74713206e-04, - 1.34419004e-05, - 1.85104660e-07, - 1.23965647e-09, - 4.03750130e-12, - 6.39515861e-15, - ], - [ - 4.50275565e-15, - 2.84275808e-12, - 8.72828077e-10, - 1.30330158e-07, - 9.46429271e-06, - 3.34240505e-04, - 5.74059467e-03, - 4.79492711e-02, - 1.94775558e-01, - 3.84781601e-01, - 3.69675978e-01, - 1.72725113e-01, - 3.92479574e-02, - 4.33716512e-03, - 2.33089213e-04, - 6.09208166e-06, - 7.74348707e-08, - 4.78668403e-10, - ], - [ - 1.95755731e-21, - 4.82320349e-18, - 5.77941614e-15, - 3.36790471e-12, - 9.54470642e-10, - 1.31550705e-07, - 8.81760261e-06, - 2.87432047e-04, - 4.55666577e-03, - 3.51307040e-02, - 1.31720486e-01, - 2.40185684e-01, - 2.12994423e-01, - 9.18579329e-02, - 1.92660386e-02, - 1.96514909e-03, - 9.74823310e-05, - 2.35170925e-06, - ], - [ - 5.02685557e-29, - 4.83367095e-25, - 2.26039866e-21, - 5.14067897e-18, - 5.68568509e-15, - 3.05825078e-12, - 7.99999982e-10, - 1.01773376e-07, - 6.29659424e-06, - 1.89454631e-04, - 2.77224261e-03, - 1.97280685e-02, - 6.82755520e-02, - 1.14914067e-01, - 9.40607618e-02, - 3.74430407e-02, - 7.24871542e-03, - 6.82461743e-04, - ], - [ - 5.43174696e-38, - 2.03835481e-33, - 3.72003749e-29, - 3.30173621e-25, - 1.42516199e-21, - 2.99167362e-18, - 3.05415190e-15, - 1.51633227e-12, - 3.66121676e-10, - 4.29917177e-08, - 2.45510656e-06, - 6.81841165e-05, - 9.20923191e-04, - 6.04910223e-03, - 1.93234986e-02, - 3.00198084e-02, - 2.26807491e-02, - 8.33362963e-03, - ], - [ - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - 0.00000000e00, - ], - ] - ) + output = physnet_cutoff_module(d_ij_angstrom / 10) # input is in nanometer - assert np.allclose(radial_expension.numpy(), expected_output, rtol=1e-3) + assert torch.allclose(output, expected_output, rtol=1e-3) def test_scatter_add(): @@ -448,55 +301,28 @@ def test_scatter_softmax(): assert torch.allclose(util_out, correct_out) -def test_embedding(): - """ - Test the Embedding module. - """ - from modelforge.potential.utils import Embedding - - max_Z = 100 - embedding_dim = 7 - - # Create Embedding instance - embedding = Embedding(max_Z, embedding_dim) - - # Test embedding_dim property - assert embedding.embedding_dim == embedding_dim - - # Test forward pass - input_tensor = torch.randint(0, 99, (5,)) - - output = embedding(input_tensor) - assert output.shape == (5, embedding_dim) - - def test_energy_readout(): from modelforge.potential.processing import FromAtomToMoleculeReduction import torch - # the EnergyReadout module performs a linear pass to reduce the nr_of_atom_basis to 1 - # and then performs a scatter add operation to return a tensor with size [nr_of_molecules,] + # the EnergyReadout module performs a linear pass to reduce the + # nr_of_atom_basis to 1 and then performs a scatter add operation to return + # a tensor with size [nr_of_molecules,] # the input for the EnergyReadout module is vector (E_i) that will be scatter_added, and # a second tensor supplying the indixes for the summation r = { - "per_atom_energy": torch.tensor([3, 3, 1, 1, 1, 1, 1, 1], dtype=torch.float32), + "per_atom_energy": torch.tensor( + [3, 3, 1, 1, 1, 1, 1, 1], dtype=torch.float32 + ).unsqueeze(1), "atomic_subsystem_index": torch.tensor([0, 0, 1, 1, 1, 1, 1, 1]), } - energy_readout = FromAtomToMoleculeReduction( - per_atom_property_name="per_atom_energy", - index_name="atomic_subsystem_index", - output_name="per_molecule_energy", - ) - E = energy_readout(r)["per_molecule_energy"] + energy_readout = FromAtomToMoleculeReduction() + E = energy_readout(r["atomic_subsystem_index"], r["per_atom_energy"]) # check that output has length of total number of molecules in batch - assert E.size() == torch.Size( - [ - 2, - ] - ) + assert E.size() == torch.Size([2, 1]) # check that the correct values were summed assert torch.isclose(E[0], torch.tensor([6.0], dtype=torch.float32)) assert torch.isclose(E[1], torch.tensor([6.0], dtype=torch.float32)) @@ -526,6 +352,10 @@ def test_welford(): assert np.isclose(online_estimator.stddev / target_stddev, 1.0, rtol=1e-1) +@pytest.mark.skipif( + ON_MACOS and IN_GITHUB_ACTIONS, + reason="Test is flaky on the MacOS CI runners as it relies on spawning multiple threads. ", +) def test_filelocking(prep_temp_dir): from modelforge.utils.misc import lock_file, unlock_file, check_file_lock @@ -548,7 +378,7 @@ def run(self): if not check_file_lock(f): lock_file(f) self.did_I_lock_it = True - time.sleep(2) + time.sleep(3) unlock_file(f) else: @@ -557,7 +387,8 @@ def run(self): # the first thread should lock the file and set "did_I_lock_it" to True thread1 = thread("lock_file_here", "Thread-1", filepath) # the second thread should check if locked, and set "did_I_lock_it" to False - # the second thread should also set "status" to True, because it waits for the first thread to unlock the file + # the second thread should also set "status" to True, because it waits + # for the first thread to unlock the file thread2 = thread("encounter_locked_file", "Thread-2", filepath) thread1.start() diff --git a/modelforge/train/__init__.py b/modelforge/train/__init__.py index e69de29b..5b4a8d9b 100644 --- a/modelforge/train/__init__.py +++ b/modelforge/train/__init__.py @@ -0,0 +1 @@ +"""Module that contains classes and functions for traning neural network potentials.""" diff --git a/modelforge/train/losses.py b/modelforge/train/losses.py new file mode 100644 index 00000000..efa64937 --- /dev/null +++ b/modelforge/train/losses.py @@ -0,0 +1,500 @@ +# losses.py + +""" +This module contains classes and functions for loss computation and error metrics +for training neural network potentials. +""" + +from abc import ABC, abstractmethod +from typing import Dict, List +import torch +from torch import nn +from loguru import logger as log + +from modelforge.dataset.dataset import NNPInput + +__all__ = [ + "Error", + "ForceSquaredError", + "EnergySquaredError", + "TotalChargeError", + "DipoleMomentError", + "Loss", + "LossFactory", + "create_error_metrics", +] + + +class Error(nn.Module, ABC): + """ + Abstract base class for error calculation between predicted and true values. + """ + + def __init__(self, scale_by_number_of_atoms: bool = True): + super().__init__() + self.scale_by_number_of_atoms = ( + self._scale_by_number_of_atoms + if scale_by_number_of_atoms + else lambda error, atomic_subsystem_counts, prefactor=1: error + ) + + @abstractmethod + def calculate_error( + self, + predicted: torch.Tensor, + true: torch.Tensor, + ) -> torch.Tensor: + """ + Calculates the error between the predicted and true values. + """ + raise NotImplementedError + + @staticmethod + def calculate_squared_error( + predicted_tensor: torch.Tensor, reference_tensor: torch.Tensor + ) -> torch.Tensor: + """ + Calculates the squared error between the predicted and true values. + """ + squared_diff = (predicted_tensor - reference_tensor).pow(2) + error = squared_diff.sum(dim=1, keepdim=True) + return error + + @staticmethod + def _scale_by_number_of_atoms( + error, atomic_counts, prefactor: int = 1 + ) -> torch.Tensor: + """ + Scales the error by the number of atoms in the atomic subsystems. + + Parameters + ---------- + error : torch.Tensor + The error to be scaled. + atomic_counts : torch.Tensor + The number of atoms in the atomic subsystems. + prefactor : int + Prefactor to adjust for the shape of the property (e.g., vector properties). + + Returns + ------- + torch.Tensor + The scaled error. + """ + scaled_by_number_of_atoms = error / (prefactor * atomic_counts.unsqueeze(1)) + return scaled_by_number_of_atoms + + +class ForceSquaredError(Error): + """ + Calculates the per-atom error and aggregates it to per-system mean squared error. + """ + + def calculate_error( + self, + per_atom_prediction: torch.Tensor, + per_atom_reference: torch.Tensor, + ) -> torch.Tensor: + """Computes the per-atom squared error.""" + return self.calculate_squared_error(per_atom_prediction, per_atom_reference) + + def forward( + self, + per_atom_prediction: torch.Tensor, + per_atom_reference: torch.Tensor, + batch: NNPInput, + ) -> torch.Tensor: + """ + Computes the per-atom error and aggregates it to per-system mean squared error. + + Parameters + ---------- + per_atom_prediction : torch.Tensor + The predicted values. + per_atom_reference : torch.Tensor + The reference values provided by the dataset. + batch : NNPInput + The batch data containing metadata and input information. + + Returns + ------- + torch.Tensor + The aggregated per-system error. + """ + + # Compute per-atom squared error + per_atom_squared_error = self.calculate_error( + per_atom_prediction, per_atom_reference + ) + + # Initialize per-system squared error tensor + per_system_squared_error = torch.zeros_like( + batch.metadata.per_system_energy, dtype=per_atom_squared_error.dtype + ) + + # Aggregate error per system + per_system_squared_error = per_system_squared_error.scatter_add( + 0, + batch.nnp_input.atomic_subsystem_indices.long().unsqueeze(1), + per_atom_squared_error, + ) + + # Scale error by number of atoms + per_system_square_error_scaled = self.scale_by_number_of_atoms( + per_system_squared_error, + batch.metadata.atomic_subsystem_counts, + prefactor=per_atom_prediction.shape[-1], + ) + + return per_system_square_error_scaled.contiguous() + + +class EnergySquaredError(Error): + """ + Calculates the per-system mean squared error. + """ + + def calculate_error( + self, + per_system_prediction: torch.Tensor, + per_system_reference: torch.Tensor, + ) -> torch.Tensor: + """Computes the per-system squared error.""" + return self.calculate_squared_error(per_system_prediction, per_system_reference) + + def forward( + self, + per_system_prediction: torch.Tensor, + per_system_reference: torch.Tensor, + batch: NNPInput, + ) -> torch.Tensor: + """ + Computes the per-system mean squared error. + + Parameters + ---------- + per_system_prediction : torch.Tensor + The predicted values. + per_system_reference : torch.Tensor + The true values. + batch : NNPInput + The batch data containing metadata and input information. + + Returns + ------- + torch.Tensor + The mean per-system error. + """ + + # Compute per-system squared error + per_system_squared_error = self.calculate_error( + per_system_prediction, per_system_reference + ) + # Scale error by number of atoms + per_system_square_error_scaled = self.scale_by_number_of_atoms( + per_system_squared_error, + batch.metadata.atomic_subsystem_counts, + ) + + return per_system_square_error_scaled + + +class TotalChargeError(Error): + """ + Calculates the error for total charge. + """ + + def calculate_error( + self, + total_charge_predict: torch.Tensor, + total_charge_true: torch.Tensor, + ) -> torch.Tensor: + """ + Computes the absolute difference between predicted and true total charges. + """ + error = torch.abs(total_charge_predict - total_charge_true) + return error # Shape: [batch_size, 1] + + def forward( + self, + total_charge_predict: torch.Tensor, + total_charge_true: torch.Tensor, + batch: NNPInput, + ) -> torch.Tensor: + """ + Computes the error for total charge. + + Parameters + ---------- + total_charge_predict : torch.Tensor + The predicted total charges. + total_charge_true : torch.Tensor + The true total charges. + batch : NNPInput + The batch data. + + Returns + ------- + torch.Tensor + The error for total charges. + """ + error = self.calculate_error(total_charge_predict, total_charge_true) + return error # No scaling needed + + +class DipoleMomentError(Error): + """ + Calculates the error for dipole moment. + """ + + def calculate_error( + self, + dipole_predict: torch.Tensor, + dipole_true: torch.Tensor, + ) -> torch.Tensor: + """ + Computes the squared difference between predicted and true dipole moments. + """ + error = ( + (dipole_predict - dipole_true).pow(2).sum(dim=1, keepdim=True) + ) # Shape: [batch_size, 1] + return error + + def forward( + self, + dipole_predict: torch.Tensor, + dipole_true: torch.Tensor, + batch: NNPInput, + ) -> torch.Tensor: + """ + Computes the error for dipole moment. + + Parameters + ---------- + dipole_predict : torch.Tensor + The predicted dipole moments. + dipole_true : torch.Tensor + The true dipole moments. + batch : NNPInput + The batch data. + + Returns + ------- + torch.Tensor + The error for dipole moments. + """ + error = self.calculate_error(dipole_predict, dipole_true) + return error # No scaling needed + + +class Loss(nn.Module): + + _SUPPORTED_PROPERTIES = [ + "per_atom_energy", + "per_system_energy", + "per_atom_force", + "per_system_total_charge", + "per_system_dipole_moment", + ] + + def __init__( + self, + loss_components: List[str], + weights_scheduling: Dict[str, torch.Tensor], + ): + """ + Calculates the combined loss for energy and force predictions. + + Parameters + ---------- + loss_components : List[str] + List of properties to include in the loss calculation. + weights : Dict[str, float] + Dictionary containing the weights for each property in the loss calculation. + + Raises + ------ + NotImplementedError + If an unsupported loss type is specified. + """ + super().__init__() + from torch.nn import ModuleDict + + self.loss_components = loss_components + self.weights_scheduling = weights_scheduling + self.loss_functions = ModuleDict() + + for prop in loss_components: + if prop not in self._SUPPORTED_PROPERTIES: + raise NotImplementedError(f"Loss type {prop} not implemented.") + + log.info(f"Using loss function for {prop}") + log.info( + f"With loss component schedule from weight: {weights_scheduling[prop][0]} to {weights_scheduling[prop][-1]}" + ) + + if prop == "per_atom_force": + self.loss_functions[prop] = ForceSquaredError( + scale_by_number_of_atoms=True + ) + elif prop == "per_atom_energy": + self.loss_functions[prop] = EnergySquaredError( + scale_by_number_of_atoms=True + ) + elif prop == "per_system_energy": + self.loss_functions[prop] = EnergySquaredError( + scale_by_number_of_atoms=False + ) + elif prop == "per_system_total_charge": + self.loss_functions[prop] = TotalChargeError() + elif prop == "per_system_dipole_moment": + self.loss_functions[prop] = DipoleMomentError() + else: + raise NotImplementedError(f"Loss type {prop} not implemented.") + + self.register_buffer(prop, self.weights_scheduling[prop]) + + def forward( + self, + predict_target: Dict[str, torch.Tensor], + batch: NNPInput, + epoch_idx: int, + ) -> Dict[str, torch.Tensor]: + """ + Calculates the combined loss for the specified properties. + + Parameters + ---------- + predict_target : Dict[str, torch.Tensor] + Dictionary containing predicted and true values for energy and forces. + batch : NNPInput + The batch data containing metadata and input information. + + Returns + ------- + Dict[str, torch.Tensor] + Individual per-sample loss terms and the combined total loss. + """ + from modelforge.train.training import ( + _exchange_per_atom_energy_for_per_system_energy, + ) + + # Save the loss as a dictionary + loss_dict = {} + # Accumulate loss + total_loss = torch.zeros_like(batch.metadata.per_system_energy) + + # Iterate over loss properties + for prop in self.loss_components: + loss_fn = self.loss_functions[prop] + + prop_ = _exchange_per_atom_energy_for_per_system_energy(prop) + # NOTE: we always predict per_system_energies, and the dataset + # also include per_system_energies. If we are normalizing these + # (indicated by the `per_atom_energy` keyword), we still operate on + # the per_system_energies but the loss function will divide the + # error by the number of atoms in the atomic subsystems. + prop_loss = loss_fn( + predict_target[f"{prop_}_predict"], + predict_target[f"{prop_}_true"], + batch, + ) + # check that none of the tensors are NaN + if torch.isnan(prop_loss).any(): + raise ValueError(f"NaN values detected in {prop} loss.") + + # Accumulate weighted per-sample losses + weighted_loss = self.weights_scheduling[prop][epoch_idx] * prop_loss + + total_loss += weighted_loss # Note: total_loss is still per-sample + loss_dict[prop] = prop_loss # Store per-sample loss + + # Add total loss to results dict and return + loss_dict["total_loss"] = total_loss + + return loss_dict + + +class LossFactory: + """ + Factory class to create different types of loss functions. + """ + + @staticmethod + def create_loss( + loss_components: List[str], + weights_scheduling: Dict[ + str, + torch.Tensor, + ], + ) -> Loss: + """ + Creates an instance of the specified loss type. + + Parameters + ---------- + loss_components : List[str] + List of properties to include in the loss calculation. + weights_scheduling : Dict[str, torch.Tensor] + Dictionary containing the weights for each property in the loss calculation. + + Returns + ------- + Loss + An instance of the specified loss function. + """ + return Loss( + loss_components, + weights_scheduling, + ) + + +from torch.nn import ModuleDict + + +def create_error_metrics( + loss_properties: List[str], + is_loss: bool = False, +) -> ModuleDict: + """ + Creates a ModuleDict of MetricCollections for the given loss properties. + + Parameters + ---------- + loss_properties : List[str] + List of loss properties for which to create the metrics. + is_loss : bool, optional + If True, only the loss metric is created, by default False. + + Returns + ------- + ModuleDict + A dictionary where keys are loss properties and values are MetricCollections. + """ + from torchmetrics import MetricCollection + from torchmetrics.aggregation import MeanMetric + from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError + + if is_loss: + metric_dict = ModuleDict( + {prop: MetricCollection([MeanMetric()]) for prop in loss_properties} + ) + metric_dict["total_loss"] = MetricCollection([MeanMetric()]) + else: + from modelforge.train.training import ( + _exchange_per_atom_energy_for_per_system_energy, + ) + + # NOTE: we are using the + # _exchange_per_atom_energy_for_per_system_energy function because, if + # the `per_atom_energy` loss (i.e., the normalize per_system_energy + # loss) is used, the validation error is still per_system_energy + metric_dict = ModuleDict( + { + _exchange_per_atom_energy_for_per_system_energy(prop): MetricCollection( + [MeanAbsoluteError(), MeanSquaredError(squared=False)] + ) # only exchange per_atom_energy for per_system_energy + for prop in loss_properties + } + ) + return metric_dict diff --git a/modelforge/train/parameters.py b/modelforge/train/parameters.py new file mode 100644 index 00000000..648c6654 --- /dev/null +++ b/modelforge/train/parameters.py @@ -0,0 +1,465 @@ +""" +This defines pydantic models for the training parameters and runtime parameters. +""" + +from enum import Enum +from typing import ( + Callable, + Dict, + List, + Optional, + Type, + Union, + Literal, + Annotated, +) + +import torch +from pydantic import BaseModel, ConfigDict, field_validator, model_validator, Field +from loguru import logger as log + + +# So we do not need to set Config parameters in each model +# we can create a base class that will be inherited by all models +class ParametersBase(BaseModel): + model_config = ConfigDict( + use_enum_values=True, arbitrary_types_allowed=True, validate_assignment=True + ) + + +# for enums over strings, we likely do not want things to be case sensitive +# this will return the appropriate enum value regardless of case +class CaseInsensitiveEnum(str, Enum): + @classmethod + def _missing_(cls, value): + for member in cls: + if member.value.lower() == value.lower(): + return member + return super()._missing_(value) + + +class SchedulerMode(CaseInsensitiveEnum): + """ + Enum class for the scheduler mode, allowing "min" and "max" + """ + + min = "min" + max = "max" + + +class ThresholdMode(CaseInsensitiveEnum): + """ + Enum class for the threshold mode, allowing "abs" and "rel" + """ + + abs = "abs" + rel = "rel" + + +class SchedulerName(CaseInsensitiveEnum): + """ + Enum class for the scheduler names + """ + + ReduceLROnPlateau = "ReduceLROnPlateau" + CosineAnnealingLR = "CosineAnnealingLR" + CosineAnnealingWarmRestarts = "CosineAnnealingWarmRestarts" + OneCycleLR = "OneCycleLR" + CyclicLR = "CyclicLR" + + +class SplittingStrategyName(CaseInsensitiveEnum): + """ + Enum class for the splitting strategy name + """ + + first_come_first_serve = "first_come_first_serve" + random_record_splitting_strategy = "random_record_splitting_strategy" + random_conformer_splitting_strategy = "random_conformer_splitting_strategy" + + +class AnnealingStrategy(CaseInsensitiveEnum): + """ + Enum class for the annealing strategy + """ + + cos = "cos" + linear = "linear" + + +class Loggers(CaseInsensitiveEnum): + """ + Enum class for the experiment logger + """ + + wandb = "wandb" + tensorboard = "tensorboard" + + +class TensorboardConfig(ParametersBase): + save_dir: str + + +class WandbConfig(ParametersBase): + save_dir: str + project: str + group: str + log_model: Union[str, bool] + job_type: Optional[str] + tags: Optional[List[str]] + notes: Optional[str] + + +class SchedulerConfigBase(ParametersBase): + """ + Base class for scheduler configurations + """ + + scheduler_name: SchedulerName + frequency: int + interval: str + monitor: Optional[str] = None + + +class ReduceLROnPlateauConfig(SchedulerConfigBase): + """ + Configuration for ReduceLROnPlateau scheduler + """ + + scheduler_name: Literal[SchedulerName.ReduceLROnPlateau] = ( + SchedulerName.ReduceLROnPlateau + ) + mode: SchedulerMode # 'min' or 'max' + factor: float + patience: int + threshold: float + threshold_mode: ThresholdMode # 'rel' or 'abs' + cooldown: int + min_lr: float + eps: float = 1e-8 + + +class CosineAnnealingLRConfig(SchedulerConfigBase): + """ + Configuration for CosineAnnealingLR scheduler + """ + + scheduler_name: Literal[SchedulerName.CosineAnnealingLR] = ( + SchedulerName.CosineAnnealingLR + ) + T_max: int + eta_min: float = 0.0 + last_epoch: int = -1 + + +class CosineAnnealingWarmRestartsConfig(SchedulerConfigBase): + """ + Configuration for CosineAnnealingWarmRestarts scheduler + """ + + scheduler_name: Literal[SchedulerName.CosineAnnealingWarmRestarts] = ( + SchedulerName.CosineAnnealingWarmRestarts + ) + T_0: int + T_mult: int = 1 + eta_min: float = 0.0 + last_epoch: int = -1 + + +class CyclicLRMode(CaseInsensitiveEnum): + """ + Enum class for the CyclicLR modes + """ + + triangular = "triangular" + triangular2 = "triangular2" + exp_range = "exp_range" + + +class ScaleMode(CaseInsensitiveEnum): + """ + Enum class for the scale modes + """ + + cycle = "cycle" + iterations = "iterations" + + +class OneCycleLRConfig(SchedulerConfigBase): + """ + Configuration for OneCycleLR scheduler + """ + + scheduler_name: Literal[SchedulerName.OneCycleLR] = SchedulerName.OneCycleLR + max_lr: Union[float, List[float]] + epochs: int # required + pct_start: float = 0.3 + anneal_strategy: AnnealingStrategy = AnnealingStrategy.cos + cycle_momentum: bool = True + base_momentum: Union[float, List[float]] = 0.85 + max_momentum: Union[float, List[float]] = 0.95 + div_factor: float = 25.0 + final_div_factor: float = 1e4 + three_phase: bool = False + last_epoch: int = -1 + + @model_validator(mode="after") + def validate_epochs(self): + if self.epochs is None: + raise ValueError("OneCycleLR requires 'epochs' to be set.") + if self.interval != "step": + raise ValueError("OneCycleLR requires 'interval' to be set to 'step'.") + + return self + + +class CyclicLRConfig(SchedulerConfigBase): + """ + Configuration for CyclicLR scheduler + """ + + scheduler_name: Literal[SchedulerName.CyclicLR] = SchedulerName.CyclicLR + base_lr: Union[float, List[float]] + max_lr: Union[float, List[float]] + epochs_up: float # Duration of the increasing phase in epochs + epochs_down: Optional[float] = ( + None # Duration of the decreasing phase in epochs (optional) + ) + mode: CyclicLRMode = CyclicLRMode.triangular + gamma: float = 1.0 + scale_mode: ScaleMode = ScaleMode.cycle + cycle_momentum: bool = True + base_momentum: Union[float, List[float]] = 0.8 + max_momentum: Union[float, List[float]] = 0.9 + last_epoch: int = -1 + + @model_validator(mode="after") + def validate_epochs(self): + if self.epochs_up is None: + raise ValueError("CyclicLR requires 'epochs_up' to be set.") + + if self.interval != "step": + raise ValueError("OneCycleLR requires 'interval' to be set to 'step'.") + return self + + +SchedulerConfig = Annotated[ + Union[ + ReduceLROnPlateauConfig, + CosineAnnealingLRConfig, + CosineAnnealingWarmRestartsConfig, + OneCycleLRConfig, + CyclicLRConfig, + ], + Field(discriminator="scheduler_name"), +] + + +class TrainingParameters(ParametersBase): + """ + A class to hold the training parameters + """ + + class LossParameter(ParametersBase): + """ + Class to hold the loss properties and mixing scheme + """ + + loss_components: List[str] + weight: Dict[str, float] + target_weight: Optional[Dict[str, float]] = None + mixing_steps: Optional[Dict[str, float]] = None + + @model_validator(mode="before") + def set_target_weight_defaults(cls, values): + + # if no target_weight is provided, set target_weight to be the same + # as weight and mixing_steps to be 1.0 + if "target_weight" not in values: + # set target_weight to be the same as weight + values["target_weight"] = values["weight"] + d = {} + for key in values["target_weight"]: + d[key] = 1.0 + values["mixing_steps"] = d + return values + + # if target weight is not provided for all properties, set the rest + # to be the same as weight + for key in values["weight"]: + # if only a subset of the loss components are provided in target_weight, set the rest to be the same as weight + if key not in values["target_weight"]: + values["target_weight"][key] = values["weight"][key] + values["mixing_steps"][key] = 1.0 + return values + + @model_validator(mode="after") + def ensure_length_match(self) -> "LossParameter": + loss_components = self.loss_components + weight = self.weight + if len(loss_components) != len(weight): + raise ValueError( + f"The length of loss_components ({len(loss_components)}) and weight ({len(weight)}) must match." + ) + if set(loss_components) != set(weight.keys()): + raise ValueError("The keys of weight must match loss_components") + + if self.target_weight or self.mixing_steps: + if not (self.target_weight and self.mixing_steps): + raise ValueError( + "If using mixing scheme, target_weight and mixing_steps must all be provided" + ) + if set(self.target_weight.keys()) != set(loss_components): + raise ValueError( + "The keys of target_weight must match loss_components" + ) + return self + + class EarlyStopping(ParametersBase): + """ + Class to hold the early stopping parameters + """ + + verbose: bool + monitor: Optional[str] = None + min_delta: float + patience: int + + class SplittingStrategy(ParametersBase): + """ + Class to hold the splitting strategy + """ + + name: SplittingStrategyName = ( + SplittingStrategyName.random_record_splitting_strategy + ) + data_split: List[float] + seed: int + + @field_validator("data_split") + def data_split_must_sum_to_one_and_length_three(cls, v) -> List[float]: + if len(v) != 3: + raise ValueError("data_split must have length of 3") + if sum(v) != 1: + raise ValueError("data_split must sum to 1") + return v + + class StochasticWeightAveraging(ParametersBase): + """ + Class to hold the stochastic weight averaging parameters + """ + + swa_lrs: Union[float, List[float]] + swa_epoch_start: float + annealing_epoch: int + annealing_strategy: AnnealingStrategy + avg_fn: Optional[Callable] = None + + class ExperimentLogger(ParametersBase): + logger_name: Loggers + tensorboard_configuration: Optional[TensorboardConfig] = None + wandb_configuration: Optional[WandbConfig] = None + + @model_validator(mode="after") + def ensure_logger_configuration(self) -> "ExperimentLogger": + if ( + self.logger_name == Loggers.tensorboard + and self.tensorboard_configuration is None + ): + raise ValueError("tensorboard_configuration must be provided") + if self.logger_name == Loggers.wandb and self.wandb_configuration is None: + raise ValueError("wandb_configuration must be provided") + return self + + monitor: str + number_of_epochs: int + remove_self_energies: bool + shift_center_of_mass_to_origin: bool + batch_size: int + lr: float + plot_frequency: int = 5 # how often to log regression and error histograms + lr_scheduler: Optional[SchedulerConfig] = None + loss_parameter: LossParameter + early_stopping: Optional[EarlyStopping] = None + splitting_strategy: SplittingStrategy + stochastic_weight_averaging: Optional[StochasticWeightAveraging] = None + experiment_logger: ExperimentLogger + verbose: bool = False + log_norm: bool = False + limit_train_batches: Union[float, int, None] = None + limit_val_batches: Union[float, int, None] = None + limit_test_batches: Union[float, int, None] = None + optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW + min_number_of_epochs: Union[int, None] = None + + @model_validator(mode="after") + def validate_dipole_and_shift_com(self): + if "dipole_moment" in self.loss_parameter.loss_components: + if not self.shift_center_of_mass_to_origin: + raise ValueError( + "Use of dipole_moment in the loss requires shift_center_of_mass_to_origin to be True" + ) + return self + + # Validator to set default monitors + @model_validator(mode="after") + def set_default_monitors(self) -> "TrainingParameters": + if self.lr_scheduler and self.lr_scheduler.monitor is None: + self.lr_scheduler.monitor = self.monitor + if self.early_stopping and self.early_stopping.monitor is None: + self.early_stopping.monitor = self.monitor + return self + + +class Accelerator(CaseInsensitiveEnum): + """ + Enum class for the accelerator, allowing "cpu", "gpu" and "tpu". + """ + + cpu = "cpu" + gpu = "gpu" + tpu = "tpu" + + +class SimulationEnvironment(CaseInsensitiveEnum): + """ + Enum class for the simulation environment, allowing "PyTorch" and "JAX". + """ + + JAX = "JAX" + PyTorch = "PyTorch" + + +class RuntimeParameters(ParametersBase): + """ + A class to hold the runtime parameters + """ + + experiment_name: str + accelerator: Accelerator + number_of_nodes: int + devices: Union[int, List[int]] + local_cache_dir: str + checkpoint_path: Union[str, None] + simulation_environment: SimulationEnvironment + log_every_n_steps: int + verbose: bool + + @field_validator("number_of_nodes") + @classmethod + def number_of_nodes_must_be_positive(cls, v) -> int: + if v < 1: + raise ValueError("number_of_nodes must be positive and greater than 0") + return v + + @field_validator("devices") + @classmethod + def device_index_must_be_positive(cls, v) -> Union[int, List[int]]: + if isinstance(v, list): + for device in v: + if device < 0: + raise ValueError("device_index must be positive") + else: + if v < 0: + raise ValueError("device_index must be positive") + return v diff --git a/modelforge/train/training.py b/modelforge/train/training.py index 260d5e5d..1d71eb7e 100644 --- a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -1,547 +1,632 @@ -from torch.optim.lr_scheduler import ReduceLROnPlateau -import lightning as pl -from typing import Any, Union, Dict, Type, Optional, List +""" +This module contains classes and functions for training neural network potentials using PyTorch Lightning. +""" + +from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple, Literal +import time +import lightning.pytorch as pL import torch +from lightning import Trainer from loguru import logger as log -from modelforge.dataset.dataset import BatchData, NNPInput -import torchmetrics -from torch import nn -from abc import ABC, abstractmethod +from openff.units import unit +from torch.nn import ModuleDict +from torch.optim import Optimizer +from torch.optim.lr_scheduler import ( + ReduceLROnPlateau, + CosineAnnealingLR, + CosineAnnealingWarmRestarts, + OneCycleLR, + CyclicLR, +) + +from modelforge.dataset.dataset import DataModule, DatasetParameters +from modelforge.potential.parameters import ( + AimNet2Parameters, + ANI2xParameters, + PaiNNParameters, + PhysNetParameters, + SAKEParameters, + SchNetParameters, + TensorNetParameters, +) +from modelforge.utils.prop import BatchData + +T_NNP_Parameters = TypeVar( + "T_NNP_Parameters", + ANI2xParameters, + SAKEParameters, + SchNetParameters, + PhysNetParameters, + PaiNNParameters, + TensorNetParameters, + AimNet2Parameters, +) + +from modelforge.train.losses import LossFactory, create_error_metrics +from modelforge.train.parameters import RuntimeParameters, TrainingParameters + +__all__ = [ + "PotentialTrainer", +] + + +def gradient_norm(model): + """ + Compute the total gradient norm of a model. + Parameters + ---------- + model : torch.nn.Module + The neural network model. -class Error(nn.Module, ABC): + Returns + ------- + float + The total gradient norm. """ - Class representing the error calculation for predicted and true values. + total_norm = 0.0 + for p in model.parameters(): + if p.grad is not None: + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm**0.5 + return total_norm - Methods: - calculate_error(predicted: torch.Tensor, true: torch.Tensor) -> torch.Tensor: - Calculates the error between the predicted and true values. - scale_by_number_of_atoms(error, atomic_subsystem_counts) -> torch.Tensor: - Scales the error by the number of atoms in the atomic subsystems. +def compute_grad_norm(loss, model): """ + Compute the gradient norm of the loss with respect to the model parameters. - @abstractmethod - def calculate_error( - self, predicted: torch.Tensor, true: torch.Tensor - ) -> torch.Tensor: - """ - Calculates the error between the predicted and true values - """ - raise NotImplementedError - - @staticmethod - def calculate_squared_error( - predicted_tensor: torch.Tensor, reference_tensor: torch.Tensor - ) -> torch.Tensor: - """ - Calculates the squared error between the predicted and true values. + Parameters + ---------- + loss : torch.Tensor + The loss tensor. + model : torch.nn.Module + The neural network model. - Parameters: - predicted_tensor (torch.Tensor): The predicted values. - reference_tensor (torch.Tensor): The values provided by the dataset. + Returns + ------- + float + The total gradient norm. + """ + parameters = [p for p in model.parameters() if p.requires_grad] + grads = torch.autograd.grad( + loss.sum(), + parameters, + retain_graph=True, + create_graph=False, + allow_unused=True, + ) + total_norm = 0.0 + for grad in grads: + if grad is not None: + param_norm = grad.detach().data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm**0.5 + return total_norm - Returns: - torch.Tensor: The calculated error. - """ - return (predicted_tensor - reference_tensor).pow(2).sum(dim=1, keepdim=True) - @staticmethod - def scale_by_number_of_atoms(error, atomic_subsystem_counts) -> torch.Tensor: - """ - Scales the error by the number of atoms in the atomic subsystems. +def _exchange_per_atom_energy_for_per_system_energy(prop: str) -> str: + """ + Rename 'per_atom_energy' to 'per_system_energy' if applicable. - Parameters: - error: The error to be scaled. - atomic_subsystem_counts: The number of atoms in the atomic subsystems. + Parameters + ---------- + prop : str + The property name (e.g., "per_atom_energy"). - Returns: - torch.Tensor: The scaled error. - """ - # divide by number of atoms - scaled_by_number_of_atoms = error / atomic_subsystem_counts.unsqueeze( - 1 - ) # FIXME: ensure that all per-atom properties have dimension (N, 1) - return scaled_by_number_of_atoms + Returns + ------- + str + The updated property name (e.g., "per_system_energy"). + """ + return "per_system_energy" if prop == "per_atom_energy" else prop -class FromPerAtomToPerMoleculeMeanSquaredError(Error): - """ - Calculates the per-atom error and aggregates it to per-molecule mean squared error. - """ +class CalculateProperties(torch.nn.Module): + _SUPPORTED_PROPERTIES = [ + "per_atom_energy", + "per_atom_force", + "per_system_energy", + "per_system_total_charge", + "per_system_dipole_moment", + ] - def __init__(self): + def __init__(self, requested_properties: List[str]): """ - Initializes the PerAtomToPerMoleculeError class. + A utility class for calculating properties such as energies and forces + from batches using a neural network model. + + Parameters + ---------- + requested_properties : List[str] + A list of properties to calculate (e.g., per_atom_energy, + per_atom_force, per_system_dipole_moment). """ super().__init__() + self.requested_properties = requested_properties + self.include_force = "per_atom_force" in self.requested_properties + self.include_charges = "per_system_total_charge" in self.requested_properties - def calculate_error( - self, - per_atom_prediction: torch.Tensor, - per_atom_reference: torch.Tensor, - ) -> torch.Tensor: - """ - Computes the per-atom error. - """ - return self.calculate_squared_error(per_atom_prediction, per_atom_reference) + # Ensure all requested properties are supported + assert all( + prop in self._SUPPORTED_PROPERTIES for prop in self.requested_properties + ), f"Unsupported property requested: {self.requested_properties}" - def forward( - self, - per_atom_prediction: torch.Tensor, - per_atom_reference: torch.Tensor, - batch: "NNPInput", - ) -> torch.Tensor: + @staticmethod + def _get_forces( + batch: BatchData, + model_prediction: Dict[str, torch.Tensor], + train_mode: bool, + ) -> Dict[str, torch.Tensor]: """ - Computes the per-atom error and aggregates it to per-molecule mean squared error. + Computes the forces from a given batch using the model. Parameters ---------- - per_atom_prediction : torch.Tensor - The predicted values. - per_atom_reference : torch.Tensor - The reference values provided by the dataset. - batch : NNPInput - The batch data containing metadata and input information. - + batch : BatchData + A single batch of data, including input features and target + energies. + model_prediction : Dict[str, torch.Tensor] + A dictionary containing the predicted energies from the model. + train_mode : bool + Whether to retain the graph for gradient computation (True for + training). Returns ------- - torch.Tensor - The aggregated per-molecule error. + Dict[str, torch.Tensor] + A dictionary containing the true and predicted forces. """ + nnp_input = batch.nnp_input + nnp_input.positions.requires_grad_(True) # Ensure gradients are enabled + # Cast to float32 and extract true forces + per_atom_force_true = batch.metadata.per_atom_force.to(torch.float32) - # squared error - per_atom_squared_error = self.calculate_error( - per_atom_prediction, per_atom_reference - ) - - per_molecule_squared_error = torch.zeros_like( - batch.metadata.E, dtype=per_atom_squared_error.dtype - ) - # Aggregate error per molecule - - per_molecule_squared_error.scatter_add_( - 0, - batch.nnp_input.atomic_subsystem_indices.long().unsqueeze(1), - per_atom_squared_error, - ) - # divide by number of atoms - per_molecule_square_error_scaled = self.scale_by_number_of_atoms( - per_molecule_squared_error, batch.metadata.atomic_subsystem_counts - ) - # return the average - return torch.mean(per_molecule_square_error_scaled) - + if per_atom_force_true.numel() < 1: + raise RuntimeError("No force can be calculated.") -class PerMoleculeMeanSquaredError(Error): - """ - Calculates the per-molecule mean squared error. + # Sum the energies before computing the gradient + total_energy = model_prediction["per_system_energy"].sum() + # Calculate forces as the negative gradient of energy w.r.t. positions + grad = torch.autograd.grad( + total_energy, + nnp_input.positions, + create_graph=train_mode, + retain_graph=train_mode, + allow_unused=False, + )[0] - """ + if grad is None: + raise RuntimeWarning("Force calculation did not return a gradient") - def __init__(self): - """ - Initializes the PerMoleculeMeanSquaredError class. - """ + per_atom_force_predict = ( + -grad.contiguous() + ) # Forces are the negative gradient of energy - super().__init__() + return { + "per_atom_force_true": per_atom_force_true, + "per_atom_force_predict": per_atom_force_predict, + } - def forward( - self, - per_molecule_prediction: torch.Tensor, - per_molecule_reference: torch.Tensor, - batch, - ) -> torch.Tensor: + @staticmethod + def _get_energies( + batch: BatchData, + model_prediction: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: """ - Computes the per-molecule mean squared error. + Compute the energies from a given batch using the model. Parameters ---------- - per_molecule_prediction : torch.Tensor - The predicted values. - per_molecule_reference : torch.Tensor - The true values. - batch : Any - The batch data containing metadata and input information. + batch : BatchData + A single batch of data, including input features and target + energies. + model_prediction : Dict[str, torch.Tensor] + The neural network model used to compute the energies. Returns ------- - torch.Tensor - The mean per-molecule error. + Dict[str, torch.Tensor] + A dictionary containing the true and predicted energies. """ + per_system_energy_true = batch.metadata.per_system_energy.to(torch.float32) + per_system_energy_predict = model_prediction["per_system_energy"] - per_molecule_squared_error = self.calculate_error( - per_molecule_prediction, per_molecule_reference - ) - per_molecule_square_error_scaled = self.scale_by_number_of_atoms( - per_molecule_squared_error, batch.metadata.atomic_subsystem_counts + # Ensure the shapes match + assert per_system_energy_true.shape == per_system_energy_predict.shape, ( + f"Shapes of true and predicted energies do not match: " + f"{per_system_energy_true.shape} != {per_system_energy_predict.shape}" ) + return { + "per_system_energy_true": per_system_energy_true, + "per_system_energy_predict": per_system_energy_predict, + } - # return the average - return torch.mean(per_molecule_square_error_scaled) - - def calculate_error( + def _get_charges( self, - per_atom_prediction: torch.Tensor, - per_atom_reference: torch.Tensor, - ) -> torch.Tensor: - """ - Computes the per-atom error. - """ - return self.calculate_squared_error(per_atom_prediction, per_atom_reference) - - -class Loss(nn.Module): - """ - Calculates the combined loss for energy and force predictions. - - Attributes - ---------- - loss_property : List[str] - List of properties to include in the loss calculation. - weight : Dict[str, float] - Dictionary containing the weights for each property in the loss calculation. - loss : nn.ModuleDict - Module dictionary containing the loss functions for each property. - """ - - _SUPPORTED_PROPERTIES = ["per_molecule_energy", "per_atom_force"] - - def __init__(self, loss_porperty: List[str], weight: Dict[str, float]): + batch: BatchData, + model_prediction: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: """ - Initializes the Loss class. + Compute total molecular charges and dipole moments from the predicted atomic charges. Parameters ---------- - loss_property : List[str] - List of properties to include in the loss calculation. - weight : Dict[str, float] - Dictionary containing the weights for each property in the loss calculation. + batch : BatchData + A batch of data containing input features and target charges. + model_prediction : Dict[str, torch.Tensor] + A dictionary containing the predicted charges from the model. - Raises - ------ - NotImplementedError - If an unsupported loss type is specified. + Returns + ------- + Dict[str, torch.Tensor] + A dictionary containing the true and predicted charges and dipole moments. """ - super().__init__() - from torch.nn import ModuleDict - - self.loss_property = loss_porperty - self.weight = weight - - self.loss = ModuleDict() + nnp_input = batch.nnp_input + per_atom_charges_predict = model_prediction[ + "per_atom_charge" + ] # Shape: (nr_of_atoms, 1) + + # Calculate predicted total charge by summing per-atom charges for each + # system + per_system_total_charge_predict = torch.zeros_like( + model_prediction["per_system_energy"] + ).scatter_add_( + dim=0, + index=nnp_input.atomic_subsystem_indices.long().unsqueeze(1), + src=per_atom_charges_predict, + ) # Shape: [nr_of_systems, 1] + + # Predict the dipole moment + per_system_dipole_moment = self._predict_dipole_moment(model_prediction, batch) - for prop, w in weight.items(): - if prop in self._SUPPORTED_PROPERTIES: - if prop == "per_atom_force": - self.loss[prop] = FromPerAtomToPerMoleculeMeanSquaredError() - elif prop == "per_molecule_energy": - self.loss[prop] = PerMoleculeMeanSquaredError() - self.register_buffer(prop, torch.tensor(w)) - else: - raise NotImplementedError(f"Loss type {prop} not implemented.") + return { + "per_system_total_charge_predict": per_system_total_charge_predict, + "per_system_total_charge_true": batch.nnp_input.per_system_total_charge, + "per_system_dipole_moment_predict": per_system_dipole_moment, + "per_system_dipole_moment_true": batch.metadata.per_system_dipole_moment, + } - def forward(self, predict_target: Dict[str, torch.Tensor], batch): + @staticmethod + def _predict_dipole_moment( + model_predictions: Dict[str, torch.Tensor], batch: BatchData + ) -> torch.Tensor: """ - Calculates the combined loss for the specified properties. + Compute the predicted dipole moment for each system based on the + predicted partial atomic charges and positions, i.e., the dipole moment + is calculated as the weighted sum of the partial charges (which requires + that the coordinates are centered). + + The dipole moment ensures that the predicted charges not only sum up to + the correct total charge but also reproduce the reference dipole moment. Parameters ---------- - predict_target : Dict[str, torch.Tensor] - Dictionary containing predicted and true values for energy and per_atom_force. - batch : Any - The batch data containing metadata and input information. + model_predictions : Dict[str, torch.Tensor] + A dictionary containing the predicted atomic charges from the model. + batch : BatchData + A batch of data containing the atomic positions and indices. Returns ------- - Dict{str, torch.Tensor] - Individual loss terms and the combined, total loss. - """ - # save the loss as a dictionary - loss_dict = {} - # accumulate loss - loss = torch.tensor( - [0.0], dtype=batch.metadata.E.dtype, device=batch.metadata.E.device - ) - # iterate over loss properties - for prop in self.loss_property: - # calculate loss per property - loss_ = self.weight[prop] * self.loss[prop]( - predict_target[f"{prop}_predict"], predict_target[f"{prop}_true"], batch - ) - # add total loss - loss = loss + loss_ - # save loss - loss_dict[f"{prop}/mse"] = loss_ - - # add total loss to results dict and return - loss_dict["total_loss"] = loss - - return loss_dict - - -class LossFactory(object): - """ - Factory class to create different types of loss functions. - """ + torch.Tensor + The predicted dipole moment for each system. + """ + per_atom_charge = model_predictions["per_atom_charge"] # Shape: [num_atoms, 1] + positions = batch.nnp_input.positions # Shape: [num_atoms, 3] + per_atom_charge = per_atom_charge # Shape: [num_atoms, 1] + per_atom_dipole_contrib = per_atom_charge * positions # Shape: [num_atoms, 3] + + indices = batch.nnp_input.atomic_subsystem_indices.long() # Shape: [num_atoms] + indices = indices.unsqueeze(-1).expand(-1, 3) # Shape: [num_atoms, 3] + + # Calculate dipole moment as the sum of dipole contributions for each + # system + dipole_predict = torch.zeros( + (model_predictions["per_system_energy"].shape[0], 3), + device=positions.device, + dtype=positions.dtype, + ).scatter_add_( + dim=0, + index=indices, + src=per_atom_dipole_contrib, + ) # Shape: [nr_of_systems, 3] + + return dipole_predict - @staticmethod - def create_loss(loss_property: List[str], weight: Dict[str, float]) -> Type[Loss]: + def forward( + self, + batch: BatchData, + model: torch.nn.Module, + train_mode: bool = False, + ) -> Dict[str, torch.Tensor]: """ - Creates an instance of the specified loss type. + Computes energies, forces, and charges from a given batch using the + model. Parameters ---------- - loss_property : List[str] - List of properties to include in the loss calculation. - weight : Dict[str, float] - Dictionary containing the weights for each property in the loss calculation. + batch : BatchData + A single batch of data, including input features and target + energies. + model : Type[torch.nn.Module] + The neural network model used to compute the properties. + train_mode : bool, optional + Whether to calculate gradients for forces (default is False). + Returns ------- - Loss - An instance of the specified loss function. + Dict[str, torch.Tensor] + The true and predicted energies and forces from the dataset and the + model. """ + predict_target = {} + nnp_input = batch.nnp_input + model_prediction = model.forward(nnp_input) - return Loss(loss_property, weight) - - -from torch.optim import Optimizer - - -from torch.nn import ModuleDict - - -def create_error_metrics(loss_properties: List[str]) -> ModuleDict: - """ - Creates a ModuleDict of MetricCollections for the given loss properties. + # Get predicted energies + energies = self._get_energies(batch, model_prediction) + predict_target.update(energies) - Parameters - ---------- - loss_properties : List[str] - List of loss properties for which to create the metrics. + # Get forces if they are included in the requested properties + if self.include_force: + forces = self._get_forces(batch, model_prediction, train_mode) + predict_target.update(forces) - Returns - ------- - ModuleDict - A dictionary where keys are loss properties and values are MetricCollections. - """ - from torchmetrics.regression import ( - MeanAbsoluteError, - MeanSquaredError, - ) - from torchmetrics import MetricCollection + # Get charges if they are included in the requested properties + if self.include_charges: + charges = self._get_charges(batch, model_prediction) + predict_target.update(charges) - return ModuleDict( - { - prop: MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ) - for prop in loss_properties - } - ) + return predict_target -class TrainingAdapter(pl.LightningModule): +class TrainingAdapter(pL.LightningModule): """ - Adapter class for training neural network potentials using PyTorch Lightning. + A Lightning module that encapsulates the training process for neural network potentials. """ def __init__( self, *, - lr_scheduler_config: Dict[str, Union[str, int, float]], - model_parameter: Dict[str, Any], - lr: float, - loss_parameter: Dict[str, Any], - dataset_statistic: Optional[Dict[str, float]] = None, - optimizer: Type[Optimizer] = torch.optim.AdamW, - verbose: bool = False, + potential_parameter: T_NNP_Parameters, + dataset_statistic: Dict[str, Dict[str, unit.Quantity]], + training_parameter: TrainingParameters, + optimizer_class: Type[Optimizer], + nr_of_training_batches: int = -1, + potential_seed: Optional[int] = None, ): """ - Initializes the TrainingAdapter with the specified model and training configuration. + Initialize the TrainingAdapter with model and training configuration. Parameters ---------- - nnp_parameters : Dict[str, Any] - The parameters for the neural network potential model. - lr_scheduler_config : Dict[str, Union[str, int, float]] - The configuration for the learning rate scheduler. - lr : float - The learning rate for the optimizer. - loss_module : Loss, optional - optimizer : Type[Optimizer], optional - The optimizer class to use for training, by default torch.optim.AdamW. + potential_parameter : T_NNP_Parameters + Parameters for the potential model. + dataset_statistic : Dict[str, Dict[str, unit.Quantity]] + Dataset statistics such as mean and standard deviation. + training_parameter : TrainingParameters + Training configuration, including optimizer, learning rate, and loss functions. + optimizer_class : Type[Optimizer] + The optimizer class to use for training. + nr_of_training_batches : int, optional + Number of training batches (default is -1). + potential_seed : Optional[int], optional + Seed for initializing the model (default is None). """ + from modelforge.potential.potential import setup_potential - from modelforge.potential import _Implemented_NNPs + self.epoch_start_time = None super().__init__() self.save_hyperparameters() + self.training_parameter = training_parameter - # Get requested model class - model_name = model_parameter["model_name"] - nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_name) - - # initialize model - self.model = nnp_class( - **model_parameter["core_parameter"], + # Setup the potential model + self.potential = setup_potential( + potential_parameter=potential_parameter, dataset_statistic=dataset_statistic, - postprocessing_parameter=model_parameter["postprocessing_parameter"], + potential_seed=potential_seed, + jit=False, + use_training_mode_neighborlist=True, ) - self.optimizer = optimizer - self.learning_rate = lr - self.lr_scheduler_config = lr_scheduler_config - - # verbose output, only True if requested - if verbose: - self.log_histograms = True - self.log_on_training_step = True - else: - self.log_histograms = False - self.log_on_training_step = False - - # initialize loss - self.loss = LossFactory.create_loss(**loss_parameter) - - # Assign the created error metrics to the respective attributes - self.test_error = create_error_metrics(loss_parameter["loss_property"]) - self.val_error = create_error_metrics(loss_parameter["loss_property"]) - self.train_error = create_error_metrics(loss_parameter["loss_property"]) + # Determine which properties to include based on loss components + self.include_force = ( + "per_atom_force" in training_parameter.loss_parameter.loss_components + ) - def _get_forces( - self, batch: "BatchData", energies: Dict[str, torch.Tensor] - ) -> Dict[str, torch.Tensor]: - """ - Computes the forces from a given batch using the model. + # Initialize the property calculation utility + self.calculate_predictions = CalculateProperties( + training_parameter.loss_parameter.loss_components + ) + self.optimizer_class = optimizer_class + self.learning_rate = training_parameter.lr + self.lr_scheduler = training_parameter.lr_scheduler - Parameters - ---------- - batch : BatchData - A single batch of data, including input features and target energies. - energies : dict - Dictionary containing predicted energies. + # Setup logging flags based on verbosity + self.log_histograms = training_parameter.verbose + self.log_on_training_step = training_parameter.verbose - Returns - ------- - Dict[str, torch.Tensor] - The true forces from the dataset and the predicted forces by the model. - """ - nnp_input = batch.nnp_input - per_atom_force_true = batch.metadata.F.to(torch.float32) + # Initialize the loss function with scheduled weights + weights_scheduling = self._setup_weights_scheduling( + training_parameter=training_parameter, + ) + self.loss = LossFactory.create_loss( + loss_components=training_parameter.loss_parameter.loss_components, + weights_scheduling=weights_scheduling, + ) - if per_atom_force_true.numel() < 1: - raise RuntimeError("No force can be calculated.") + # Initialize performance metrics for different phases + self.test_metrics = create_error_metrics( + training_parameter.loss_parameter.loss_components + ) + self.val_metrics = create_error_metrics( + training_parameter.loss_parameter.loss_components + ) + self.train_metrics = create_error_metrics( + training_parameter.loss_parameter.loss_components + ) - per_molecule_energy_predict = energies["per_molecule_energy_predict"] + self.loss_metrics = create_error_metrics( + training_parameter.loss_parameter.loss_components, is_loss=True + ) - # Ensure E_predict and nnp_input.positions require gradients and are on the same device - if not per_molecule_energy_predict.requires_grad: - per_molecule_energy_predict.requires_grad = True - if not nnp_input.positions.requires_grad: - nnp_input.positions.requires_grad = True + # Initialize dictionaries to store predictions and targets + self.train_preds: Dict[str, Dict[int, torch.Tensor]] = { + "energy": {}, + "force": {}, + } + self.train_targets: Dict[str, Dict[int, torch.Tensor]] = { + "energy": {}, + "force": {}, + } + self.val_preds: Dict[str, Dict[int, torch.Tensor]] = { + "energy": {}, + "force": {}, + } + self.val_targets: Dict[str, Dict[int, torch.Tensor]] = { + "energy": {}, + "force": {}, + } + self.test_preds: Dict[str, Dict[int, torch.Tensor]] = { + "energy": {}, + "force": {}, + } + self.test_targets: Dict[str, Dict[int, torch.Tensor]] = { + "energy": {}, + "force": {}, + } - # Compute the gradient (forces) from the predicted energies - grad = torch.autograd.grad( - per_molecule_energy_predict.sum(), - nnp_input.positions, - create_graph=False, - retain_graph=True, - )[0] - per_atom_force_predict = -1 * grad # Forces are the negative gradient of energy + # Initialize indices for validation and testing NOTE: this indices map + # back to the dataset + self.val_indices: Dict[int, torch.Tensor] = {} + self.test_indices: Dict[int, torch.Tensor] = {} + self.train_indices: Dict[int, torch.Tensor] = {} - return { - "per_atom_force_true": per_atom_force_true, - "per_atom_force_predict": per_atom_force_predict, - } + # Track outlier errors over epochs + self.outlier_errors_over_epochs: Dict[str, int] = {} + self.number_of_training_batches = nr_of_training_batches - def _get_energies(self, batch: "BatchData") -> Dict[str, torch.Tensor]: + def _setup_weights_scheduling( + self, training_parameter: TrainingParameters + ) -> Dict[str, torch.Tensor]: """ - Computes the energies from a given batch using the model. + Setup weight scheduling for loss components over epochs. Parameters ---------- - batch : BatchData - A single batch of data, including input features and target energies. + training_parameter : TrainingParameters + The training configuration. Returns ------- Dict[str, torch.Tensor] - The true energies from the dataset and the predicted energies by the model. + A dictionary mapping loss component names to their scheduled + weights. """ - nnp_input = batch.nnp_input - per_molecule_energy_true = batch.metadata.E.to(torch.float32) - per_molecule_energy_predict = self.model.forward(nnp_input)[ - "per_molecule_energy" - ].unsqueeze( - 1 - ) # FIXME: ensure that all per-molecule properties have dimension (N, 1) - assert per_molecule_energy_true.shape == per_molecule_energy_predict.shape, ( - f"Shapes of true and predicted energies do not match: " - f"{per_molecule_energy_true.shape} != {per_molecule_energy_predict.shape}" - ) - return { - "per_molecule_energy_true": per_molecule_energy_true, - "per_molecule_energy_predict": per_molecule_energy_predict, - } - def _get_predictions(self, batch: "BatchData") -> Dict[str, torch.Tensor]: + weights_scheduling: Dict[str, torch.Tensor] = {} + initial_weights = training_parameter.loss_parameter.weight + nr_of_epochs = training_parameter.number_of_epochs + + for key, initial_weight in initial_weights.items(): + target_weight = training_parameter.loss_parameter.target_weight[key] + mixing_steps = training_parameter.loss_parameter.mixing_steps[key] + + # Create a linear schedule from initial to target weight + mixing_scheme = torch.arange( + start=initial_weight, + end=target_weight, + step=mixing_steps, + ) + assert ( + len(mixing_scheme) < nr_of_epochs + ), "The number of epochs is less than the number of steps in the weight scheduling" + + # Fill up the rest of the epochs with the target weight + weights_scheduling[key] = torch.cat( + [ + mixing_scheme, + torch.ones(nr_of_epochs - mixing_scheme.shape[0]) * target_weight, + ] + ) + assert ( + weights_scheduling[key].shape[0] == nr_of_epochs + ), "Weight scheduling length mismatch." + return weights_scheduling + + def forward(self, batch: BatchData) -> Dict[str, torch.Tensor]: """ - Computes the energies and forces from a given batch using the model. + Forward pass to compute energies, forces, and other properties from a + batch. Parameters ---------- batch : BatchData - A single batch of data, including input features and target energies. + A batch of data including input features and target properties. Returns ------- Dict[str, torch.Tensor] - The true and predicted energies and forces from the dataset and the model. + Dictionary of predicted properties (energies, forces, etc.). """ - energies = self._get_energies(batch) - forces = self._get_forces(batch, energies) - return {**energies, **forces} + return self.potential(batch) def config_prior(self): """ Configures model-specific priors if the model implements them. """ - if hasattr(self.model, "_config_prior"): - return self.model._config_prior() + if hasattr(self.potential, "_config_prior"): + return self.potential._config_prior() log.warning("Model does not implement _config_prior().") raise NotImplementedError() + @staticmethod def _update_metrics( - self, - error_dict: Dict[str, torchmetrics.MetricCollection], + metrics: ModuleDict, predict_target: Dict[str, torch.Tensor], ): """ - Updates the provided metric collections with the predicted and true targets. + Updates the provided metric collections with the predicted and true + targets. Parameters ---------- - error_dict : dict - Dictionary containing metric collections for energy and force. - predict_target : dict - Dictionary containing predicted and true values for energy and force. - - Returns - ------- - Dict[str, torch.Tensor] - Dictionary containing updated metrics. + metrics : ModuleDict + Metric collections for energy and force evaluation. + predict_target : Dict[str, torch.Tensor] + Dictionary containing predicted and true values for properties. """ - for property, metrics in error_dict.items(): - for _, error_log in metrics.items(): - error_log( - predict_target[f"{property}_predict"].detach(), - predict_target[f"{property}_true"].detach(), - ) - - def training_step(self, batch: "BatchData", batch_idx: int) -> torch.Tensor: + for prop, metric_collection in metrics.items(): + prop = _exchange_per_atom_energy_for_per_system_energy( + prop + ) # only exchange per_atom_energy for per_system_energy + preds = predict_target[f"{prop}_predict"].detach() + targets = predict_target[f"{prop}_true"].detach() + metric_collection.update(preds, targets) + + def on_validation_epoch_start(self): + """Reset validation metrics at the start of the validation epoch.""" + self._reset_metrics(self.val_metrics) + + def on_test_epoch_start(self): + """Reset test metrics at the start of the test epoch.""" + self._reset_metrics(self.test_metrics) + + def _reset_metrics(self, metrics: ModuleDict): + """Utility function to reset all metrics in a ModuleDict.""" + for metric_collection in metrics.values(): + for metric in metric_collection.values(): + metric.reset() + + def training_step( + self, + batch: BatchData, + batch_idx: int, + ) -> torch.Tensor: """ Training step to compute the MSE loss for a given batch. @@ -558,95 +643,718 @@ def training_step(self, batch: "BatchData", batch_idx: int) -> torch.Tensor: The loss tensor computed for the current training step. """ - # calculate energy and forces - predict_target = self._get_predictions(batch) + # Calculate predictions based on the current batch + predict_target = self.calculate_predictions( + batch, self.potential, self.training + ) + + # Compute loss using the loss factory + loss_dict = self.loss( + predict_target, + batch, + self.current_epoch, + ) # Contains per-sample losses + + # Update loss metrics with per-sample losses + batch_size = batch.batch_size() + for key, metric in loss_dict.items(): + self.loss_metrics[key].update(metric.detach(), batch_size=batch_size) + + # Compute and log gradient norms for each loss component + if self.training_parameter.log_norm: + if key == "total_loss": + continue # Skip total loss for gradient norm logging + grad_norm = compute_grad_norm(metric.mean(), self) + self.log(f"grad_norm/{key}", grad_norm, sync_dist=True) + + # Save energy predictions and targets + self._update_predictions( + predict_target, + self.train_preds, + self.train_targets, + self.train_indices, + batch_idx, + batch, + ) + + # Compute the mean loss for optimization + total_loss = loss_dict["total_loss"].mean() + return total_loss - # calculate the loss - loss_dict = self.loss(predict_target, batch) + def validation_step(self, batch: BatchData, batch_idx: int) -> None: + """ + Validation step to compute validation loss and metrics. + """ - # Update and log training error - self._update_metrics(self.train_error, predict_target) + # Ensure positions require gradients for force calculation + batch.nnp_input.positions.requires_grad_(True) + with torch.set_grad_enabled(True): + # calculate energy and forces + predict_target = self.calculate_predictions( + batch, self.potential, self.potential.training + ) - # log the loss (this includes the individual contributions that the loss contains) - for key, loss in loss_dict.items(): - self.log( - f"loss/{key}", - torch.mean(loss), - on_step=False, - prog_bar=True, - on_epoch=True, - batch_size=1, - ) # batch size is 1 because the mean of the batch is logged + # Update validation metrics + self._update_metrics(self.val_metrics, predict_target) + + # Save energy predictions and targets + self._update_predictions( + predict_target, + self.val_preds, + self.val_targets, + self.val_indices, + batch_idx, + batch, + ) - return loss_dict["total_loss"] + def test_step(self, batch: BatchData, batch_idx: int) -> None: + """ + Test step to compute the test loss and metrics. + """ + # Ensure positions require gradients for force calculation + batch.nnp_input.positions.requires_grad_(True) + with torch.set_grad_enabled(True): + # calculate energy and forces + predict_target = self.calculate_predictions( + batch, self.potential, self.training + ) + # Update and log metrics + self._update_metrics(self.test_metrics, predict_target) + + # Save energy predictions and targets + self._update_predictions( + predict_target, + self.test_preds, + self.test_targets, + self.test_indices, + batch_idx, + batch, + ) - @torch.enable_grad() - def validation_step(self, batch: "BatchData", batch_idx: int) -> None: + def _update_predictions( + self, + predict_target: Dict[str, torch.Tensor], + preds: Dict[str, Dict[int, torch.Tensor]], + targets: Dict[str, Dict[int, torch.Tensor]], + indices: Dict[int, torch.Tensor], + batch_idx: int, + batch: BatchData, + ): """ - Validation step to compute the RMSE/MAE across epochs. + Update the predictions and targets dictionaries with the provided data. Parameters ---------- - batch : BatchData - The batch of data provided for validation. + predict_target : Dict[str, torch.Tensor] + The predicted and true values for properties. + preds : Dict[str, Dict[int, torch.Tensor]] + Dictionary to store predictions. + targets : Dict[str, Dict[int, torch.Tensor]] + Dictionary to store targets. + indices : Dict[int, torch.Tensor] + Dictionary to store indices referencing the dataset. batch_idx : int The index of the current batch. - - Returns - ------- - None + batch : BatchData + The current batch of data. """ + # Update energy predictions and targets + preds["energy"].update( + {batch_idx: predict_target["per_system_energy_predict"].detach().cpu()} + ) + targets["energy"].update( + {batch_idx: predict_target["per_system_energy_true"].detach().cpu()} + ) + # Save dataset indices + indices.update( + { + batch_idx: batch.metadata.atomic_subsystem_indices_referencing_dataset.detach().cpu() + } + ) + # Save force predictions and targets if forces are included + if "per_atom_force_predict" in predict_target: + preds["force"].update( + {batch_idx: predict_target["per_atom_force_predict"].detach().cpu()} + ) + targets["force"].update( + {batch_idx: predict_target["per_atom_force_true"].detach().cpu()} + ) - # Ensure positions require gradients for force calculation - batch.nnp_input.positions.requires_grad_(True) - # calculate energy and forces - predict_target = self._get_predictions(batch) - # calculate the loss - loss = self.loss(predict_target, batch) - # log the loss - self._update_metrics(self.val_error, predict_target) - - @torch.enable_grad() - def test_step(self, batch: "BatchData", batch_idx: int) -> None: + def _get_energy_tensors( + self, + preds: Dict[int, torch.Tensor], + targets: Dict[int, torch.Tensor], + indices: Dict[int, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: """ - Test step to compute the RMSE loss for a given batch. - - This method is called automatically during the test loop of the training process. It computes - the loss on a batch of test data and logs the results for analysis. + Gathers and pads prediction and target tensors across processes. Parameters ---------- - batch : BatchData - The batch of data to test the model on. - batch_idx : int - The index of the batch within the test dataset. + preds : Dict[int, torch.Tensor] + Dictionary of predictions from different batches. + targets : Dict[int, torch.Tensor] + Dictionary of targets from different batches. + indices : Dict[int, torch.Tensor] + Dictionary of indices from different batches. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int] + Gathered predictions, targets, indices, maximum length, and padding size. + """ + # Concatenate the tensors + preds_tensor = torch.cat(list(preds.values())) + targets_tensor = torch.cat(list(targets.values())) + indices_tensor = torch.cat(list(indices.values())).unique() + + # Get maximum length across all processes + local_length = torch.tensor([preds_tensor.size(0)], device=preds_tensor.device) + max_length = int(self.all_gather(local_length).max()) + + pad_size = max_length - preds_tensor.size(0) + if pad_size > 0: + log.debug(f"Padding tensors to the same length: {max_length}") + log.debug(f"Triggered at device: {self.global_rank}") + preds_tensor = torch.nn.functional.pad(preds_tensor, (0, pad_size)) + targets_tensor = torch.nn.functional.pad(targets_tensor, (0, pad_size)) + + # Gather across processes + gathered_preds = self.all_gather(preds_tensor) + gathered_targets = self.all_gather(targets_tensor) + gathered_indices = self.all_gather(indices_tensor) + + return gathered_preds, gathered_targets, gathered_indices, max_length, pad_size + + def _get_force_tensors( + self, preds: Dict[int, torch.Tensor], targets: Dict[int, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + """ + Gathers and pads force prediction and target tensors across processes. + + Parameters + ---------- + preds : Dict[int, torch.Tensor] + Dictionary of force predictions from different batches. + targets : Dict[int, torch.Tensor] + Dictionary of force targets from different batches. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, int, int] + Gathered force predictions, targets, maximum length, and padding size. + """ + # Concatenate the tensors + preds_tensor = torch.cat(list(preds.values())) + targets_tensor = torch.cat(list(targets.values())) + + # Get maximum length across all processes + local_length = torch.tensor([preds_tensor.size(0)], device=preds_tensor.device) + max_length = int(self.all_gather(local_length).max()) + + pad_size = max_length - preds_tensor.size(0) + if pad_size > 0: + log.debug(f"Padding force tensors to the same length: {max_length}") + log.debug(f"Triggered at device: {self.global_rank}") + # For forces, pad the last dimension (x, y, z) + preds_tensor = torch.nn.functional.pad(preds_tensor, (0, 0, 0, pad_size)) + targets_tensor = torch.nn.functional.pad( + targets_tensor, (0, 0, 0, pad_size) + ) + # Gather across processes + gathered_preds = self.all_gather(preds_tensor) + gathered_targets = self.all_gather(targets_tensor) + + return gathered_preds, gathered_targets, max_length, pad_size + + def _log_force_errors( + self, + preds: Dict[str, Dict[int, torch.Tensor]], + targets: Dict[str, Dict[int, torch.Tensor]], + indices: Dict[int, torch.Tensor], + phase: str, + ): + """ + Log the force error statistics as histograms. + + Parameters + ---------- + preds : Dict[str, Dict[int, torch.Tensor]] + Dictionary of force predictions. + targets : Dict[str, Dict[int, torch.Tensor]] + Dictionary of force targets. + indices : Dict[int, torch.Tensor] + Dictionary of indices referencing the dataset. + phase : str + The phase name ('train', 'val', or 'test'). + """ + + # Gather tensors + gathered_preds, gathered_targets, max_length, pad_size = ( + self._get_force_tensors( + preds["force"], + targets["force"], + ) + ) + + if self.global_rank == 0: + # Remove padding + total_length = max_length * self.trainer.world_size + gathered_preds = gathered_preds.reshape(total_length, 3)[ + : total_length - pad_size * self.trainer.world_size + ] + gathered_targets = gathered_targets.reshape(total_length, 3)[ + : total_length - pad_size * self.trainer.world_size + ] + errors = gathered_targets - gathered_preds + errors_magnitude = errors.norm(dim=1) # Compute magnitude of force errors + # make sure that errors are finite + assert torch.all( + torch.isfinite(errors_magnitude) + ), "Force errors contain NaN or Inf values." + + # Create histogram + histogram_fig = self._create_error_histogram( + errors_magnitude, + title=f"{phase.capitalize()} Magnitude of Force Error Histogram - Epoch {self.current_epoch}", + ) + + self._log_plots(phase, None, histogram_fig, force=True) + + def _log_plots(self, phase: str, regression_fig, histogram_fig, force=False): + """ + Log the regression and error histogram plots for the given phase. + + Parameters + ---------- + phase : str + The phase name ('train', 'val', or 'test'). + regression_fig : matplotlib.figure.Figure + The regression plot figure. + histogram_fig : matplotlib.figure.Figure + The error histogram figure. + force : bool, optional + Whether to indicate force-related plots (default is False). Returns ------- None - The results are logged and not directly returned. """ - # Ensure positions require gradients for force calculation - batch.nnp_input.positions.requires_grad_(True) - # calculate energy and forces - predict_target = self._get_predictions(batch) - # Update and log metrics - self._update_metrics(self.test_error, predict_target) - def on_test_epoch_end(self): + logger_name = self.training_parameter.experiment_logger.logger_name.lower() + plot_frequency = ( + self.training_parameter.plot_frequency + ) # how often to log plots + + if logger_name == "wandb": + import wandb + + # Log only every nth epoch for validation, but always log for test + if phase == "test" or self.current_epoch % plot_frequency == 0: + # NOTE: only log every nth epoch for validation, but always log + # for test + if not force and regression_fig is not None: + # Log histogram of errors and regression plot + self.logger.experiment.log( + {f"{phase}/regression_plot": wandb.Image(regression_fig)}, + # step=self.current_epoch + ) + self.logger.experiment.log( + { + f"{phase}/{'force_' if force else 'energy_'}error_histogram": wandb.Image( + histogram_fig + ) + }, + # step=self.current_epoch + ) + + elif logger_name == "tensorboard": + # Similar adjustments for tensorboard + if phase == "test" or self.current_epoch % plot_frequency == 0: + if not force and regression_fig is not None: + self.logger.experiment.add_figure( + f"{phase}_regression_plot_epoch_{self.current_epoch}", + regression_fig, + self.current_epoch, + ) + self.logger.experiment.add_figure( + f"{phase}_{'force_' if force else ''}error_histogram_epoch_{self.current_epoch}", + histogram_fig, + self.current_epoch, + ) + else: + log.warning(f"No logger found to log {phase} plots") + + import matplotlib.pyplot as plt + + # Close the figures + if regression_fig is not None: + plt.close(regression_fig) + plt.close(histogram_fig) + + def _create_regression_plot( + self, targets: torch.Tensor, predictions: torch.Tensor, title="Regression Plot" + ): + """ + Creates a regression plot comparing true targets and predictions. + + Parameters + ---------- + targets : torch.Tensor + Array of true target values. + predictions : torch.Tensor + Array of predicted values. + title : str, optional + Title of the plot. Default is 'Regression Plot'. + + Returns + ------- + matplotlib.figure.Figure + The regression plot figure. """ - Operations to perform at the end of the test set pass. + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + targets = targets.cpu().numpy() + predictions = predictions.cpu().numpy() + ax.scatter(targets, predictions, alpha=0.5) + ax.plot([targets.min(), targets.max()], [targets.min(), targets.max()], "r--") + ax.set_xlabel("True Values") + ax.set_ylabel("Predicted Values") + ax.set_title(title) + return fig + + def _create_error_histogram(self, errors: torch.Tensor, title="Error Histogram"): """ - self._log_on_epoch(log_mode="test") + Create an error histogram plot. + + Parameters + ---------- + errors : torch.Tensor + Tensor of error magnitudes. + title : str, optional + Title of the histogram (default is 'Error Histogram'). + + Returns + ------- + matplotlib.figure.Figure + The error histogram figure. + """ + import matplotlib.pyplot as plt + import numpy as np + + errors_np = errors.cpu().numpy().flatten() + + # Compute mean and standard deviation + mean_error = np.mean(errors_np) + std_error = np.std(errors_np) + + fig, ax = plt.subplots(figsize=(8, 6)) + bins = 50 + + # Plot histogram and get bin data + counts, bin_edges, patches = ax.hist( + errors_np, bins=bins, alpha=0.75, edgecolor="black" + ) + + # Set y-axis to log scale + ax.set_yscale("log") + + # Highlight outlier bins beyond 3 standard deviations + for count, edge_left, edge_right, patch in zip( + counts, bin_edges[:-1], bin_edges[1:], patches + ): + if (edge_left < mean_error - 3 * std_error) or ( + edge_right > mean_error + 3 * std_error + ): + patch.set_facecolor("red") + else: + patch.set_facecolor("blue") + + # Add vertical lines for mean and standard deviations + ax.axvline(mean_error, color="k", linestyle="dashed", linewidth=1, label="Mean") + ax.axvline( + mean_error + 3 * std_error, + color="r", + linestyle="dashed", + linewidth=1, + label="±3 Std Dev", + ) + ax.axvline( + mean_error - 3 * std_error, color="r", linestyle="dashed", linewidth=1 + ) + + ax.set_xlabel("Error") + ax.set_ylabel("Frequency (Log Scale)") + ax.set_title(title) + ax.legend() + + return fig + + def _log_figures_for_each_phase( + self, + preds: Dict[str, Dict[int, torch.Tensor]], + targets: Dict[str, Dict[int, torch.Tensor]], + indices: Dict[int, torch.Tensor], + phase: Literal["train", "val", "test"], + ): + """ + Log regression plots and error histograms for a specific phase. + + Parameters + ---------- + preds : Dict[str, Dict[int, torch.Tensor]] + Dictionary of predictions. + targets : Dict[str, Dict[int, torch.Tensor]] + Dictionary of targets. + indices : Dict[int, torch.Tensor] + Dictionary of dataset indices. + phase : Literal["train", "val", "test"] + The phase name. + + """ + # Gather across processes + gathered_preds, gathered_targets, gathered_indices, max_length, pad_size = ( + self._get_energy_tensors( + preds["energy"], + targets["energy"], + indices, + ) + ) + + # Proceed only on main process + if self.global_rank == 0: + # Remove padding + total_length = max_length * self.trainer.world_size + gathered_preds = gathered_preds.reshape(total_length)[ + : total_length - pad_size * self.trainer.world_size + ] + gathered_targets = gathered_targets.reshape(total_length)[ + : total_length - pad_size * self.trainer.world_size + ] + gathered_indices = gathered_indices.reshape(total_length)[ + : total_length - pad_size * self.trainer.world_size + ] + + errors = gathered_targets - gathered_preds + if errors.size == 0: + log.warning("Errors array is empty.") + + # Create regression plot + regression_fig = self._create_regression_plot( + gathered_targets, + gathered_preds, + title=f"{phase.capitalize()} Regression Plot - Epoch {self.current_epoch}", + ) + + # Generate error histogram plot + histogram_fig = self._create_error_histogram( + errors, + title=f"{phase.capitalize()} Error Histogram - Epoch {self.current_epoch}", + ) + self._log_plots(phase, regression_fig, histogram_fig) + + # Log outlier error counts for non-training phases + if phase != "train": + self._identify__and_log_top_k_errors(errors, gathered_indices, phase) + self.log_dict( + self.outlier_errors_over_epochs, on_epoch=True, rank_zero_only=True + ) + + def _identify__and_log_top_k_errors( + self, + errors: torch.Tensor, + indices: torch.Tensor, + phase: Literal["train", "val", "test"], + k: int = 3, + ): + """ + Identify and log the top k largest errors. + + Parameters + ---------- + errors : torch.Tensor + Tensor of error magnitudes. + indices : torch.Tensor + Tensor of dataset indices corresponding to the errors. + phase : Literal["train", "val", "test"] + The phase name. + k : int, optional + Number of top errors to track (default is 3). + + """ + + # Compute absolute errors + abs_errors = torch.abs(errors).detach().cpu() + # Flatten tensors + abs_errors = abs_errors.flatten() + indices = indices.flatten().long().detach().cpu() + + # Get top k errors and their corresponding indices + top_k_errors, top_k_indices = torch.topk(abs_errors, k) + + top_k_indices = indices[top_k_indices].tolist() + for idx, error in zip(top_k_indices, top_k_errors.tolist()): + key = f"outlier_count/{phase}/{idx}" + if key not in self.outlier_errors_over_epochs: + self.outlier_errors_over_epochs[key] = 0 + self.outlier_errors_over_epochs[key] += 1 + log.info( + f"{self.current_epoch}: {phase} : Outlier error {error} at index {idx}." + ) + + def _clear_error_tracking(self, preds, targets, incides): + """ + Clear the prediction, target, and index tracking dictionaries. + + Parameters + ---------- + preds : Dict[str, Dict[int, torch.Tensor]] + Dictionary of predictions. + targets : Dict[str, Dict[int, torch.Tensor]] + Dictionary of targets. + indices : Dict[int, torch.Tensor] + Dictionary of dataset indices. + + """ + for d in [preds, targets]: + d["energy"].clear() + d["force"].clear() + incides.clear() + + def on_test_epoch_end(self): + """Logs metrics and figures at the end of the test epoch.""" + self._log_metrics(self.test_metrics, "test") + self._log_figures_for_each_phase( + self.test_preds, + self.test_targets, + self.test_indices, + "test", + ) + # Clear the dictionaries after logging + self._clear_error_tracking( + self.test_preds, + self.test_targets, + self.test_indices, + ) + + def on_validation_epoch_end(self): + """Logs metrics and figures at the end of the validation epoch.""" + self._log_metrics(self.val_metrics, "val") + self._log_figures_for_each_phase( + self.val_preds, + self.val_targets, + self.val_indices, + "val", + ) + # Clear the dictionaries after logging + self._clear_error_tracking( + self.val_preds, + self.val_targets, + self.val_indices, + ) + + def on_train_epoch_start(self): + """Start the epoch timer.""" + self.epoch_start_time = time.time() + + def _log_time(self): + """Log the time taken per epoch to W&B.""" + epoch_time = time.time() - self.epoch_start_time + if isinstance(self.logger, pL.loggers.WandbLogger): + # Log epoch duration to W&B + self.logger.experiment.log( + {"epoch_time": epoch_time, "epoch": self.current_epoch} + ) + else: + log.warning("Weights & Biases logger not found; epoch time not logged.") def on_train_epoch_end(self): + """Logs metrics, learning rate, histograms, and figures at the end of the training epoch.""" + self._log_metrics(self.loss_metrics, "loss") + # this performs gather operations and logs only at rank == 0 + self._log_figures_for_each_phase( + self.train_preds, + self.train_targets, + self.train_indices, + "train", + ) + if self.include_force: + self._log_force_errors( + self.train_preds, + self.train_targets, + self.train_indices, + "train", + ) + # Clear the dictionaries after logging + self._clear_error_tracking( + self.train_preds, + self.train_targets, + self.train_indices, + ) + + self._log_learning_rate() + self._log_time() + self._log_histograms() + # log the weights of the different loss components + if self.trainer.is_global_zero: + for key, weight in self.loss.weights_scheduling.items(): + self.log( + f"loss/{key}/weight", + weight[self.current_epoch], + rank_zero_only=True, + ) + + def _log_learning_rate(self): + """Logs the current learning rate.""" + sch = self.lr_schedulers() + if self.trainer.is_global_zero: + try: + self.log( + "lr", + sch.get_last_lr()[0], + on_epoch=True, + prog_bar=True, + rank_zero_only=True, + ) + except AttributeError: + pass + + def _log_metrics(self, metrics: ModuleDict, phase: str): """ - Operations to perform at the end of each training epoch. + Log all accumulated metrics for a given phase. + + Parameters + ---------- + metrics : ModuleDict + The metrics to log. + phase : str + The phase name ('train', 'val', or 'test'). + + """ + # abbreviate long names to shorter versions + abbreviate = { + "MeanAbsoluteError": "mae", + "MeanSquaredError": "rmse", + "MeanMetric": "mse", # NOTE: MeanMetric is the MSE since we accumulate the squared error + } # NOTE: MeanSquaredError(squared=False) is RMSE - Logs histograms of weights and biases, and learning rate. - Also, resets validation loss. + for prop, metric_collection in metrics.items(): + for metric_name, metric in metric_collection.items(): + metric_value = metric.compute() + metric.reset() + self.log( + f"{phase}/{prop}/{abbreviate[metric_name]}", + metric_value, + prog_bar=True, + sync_dist=True, + ) + + def _log_histograms(self): + """ + Log histograms of model parameters and their gradients if enabled. """ - if self.log_histograms == True: + if self.log_histograms: for name, params in self.named_parameters(): if params is not None: self.logger.experiment.add_histogram( @@ -657,477 +1365,780 @@ def on_train_epoch_end(self): f"{name}.grad", params.grad, self.current_epoch ) - sch = self.lr_schedulers() - try: - self.log("lr", sch.get_last_lr()[0], on_epoch=True, prog_bar=True) - except AttributeError: - pass + def configure_optimizers(self): + """ + Configure the optimizers and learning rate schedulers. - self._log_on_epoch() + Returns + ------- + Dict[str, Any] + A dictionary containing the optimizer and optionally the scheduler. + """ + from modelforge.train.parameters import ( + ReduceLROnPlateauConfig, + CosineAnnealingLRConfig, + CosineAnnealingWarmRestartsConfig, + OneCycleLRConfig, + CyclicLRConfig, + ) - def _log_on_epoch(self, log_mode: str = "train"): - # convert long names to shorter versions - conv = { - "MeanAbsoluteError": "mae", - "MeanSquaredError": "rmse", - } # NOTE: MeanSquaredError(squared=False) is RMSE + # Separate parameters into weight and bias groups + weight_params = [] + bias_params = [] + + for name, param in self.potential.named_parameters(): + if ( + "weight" in name + or "atomic_shift" in name + or "gate" in name + or "agh" in name + ): + weight_params.append(param) + elif "bias" in name or "atomic_scale" in name: + bias_params.append(param) + else: + # If parameter type is unknown, raise an error + raise ValueError(f"Unknown parameter type: {name}") + + # Define parameter groups with different weight decay + param_groups = [ + { + "params": weight_params, + "lr": self.learning_rate, + "weight_decay": 1e-3, # Apply weight decay to weights + }, + { + "params": bias_params, + "lr": self.learning_rate, + "weight_decay": 0.0, # No weight decay for biases + }, + ] + + optimizer = torch.optim.AdamW(param_groups) + + lr_scheduler_config = self.lr_scheduler + + if lr_scheduler_config is None: + return {"optimizer": optimizer} + + interval = lr_scheduler_config.interval + frequency = lr_scheduler_config.frequency + monitor = ( + lr_scheduler_config.monitor or self.monitor + ) # Use default monitor if not specified + + # Determine the scheduler class and parameters + if isinstance(lr_scheduler_config, ReduceLROnPlateauConfig): + scheduler_class = ReduceLROnPlateau + scheduler_params = lr_scheduler_config.model_dump( + exclude={"scheduler_name", "frequency", "interval", "monitor"} + ) + elif isinstance(lr_scheduler_config, CosineAnnealingLRConfig): + scheduler_class = CosineAnnealingLR + scheduler_params = lr_scheduler_config.model_dump( + exclude={"scheduler_name", "frequency", "interval", "monitor"} + ) + elif isinstance(lr_scheduler_config, CosineAnnealingWarmRestartsConfig): + scheduler_class = CosineAnnealingWarmRestarts + scheduler_params = lr_scheduler_config.model_dump( + exclude={"scheduler_name", "frequency", "interval", "monitor"} + ) + elif isinstance(lr_scheduler_config, OneCycleLRConfig): + scheduler_class = OneCycleLR + scheduler_params = lr_scheduler_config.model_dump( + exclude={ + "scheduler_name", + "frequency", + "interval", + "monitor", + "steps_per_epoch", + "total_steps", + } + ) - # Log all accumulated metrics for train and val phases - if log_mode == "train": - errors = [ - ("train", self.train_error), - ("val", self.val_error), - ] - elif log_mode == "test": - errors = [ - ("test", self.test_error), - ] + # Calculate steps_per_epoch + steps_per_epoch = self.number_of_training_batches + scheduler_params["steps_per_epoch"] = steps_per_epoch + scheduler_params["epochs"] = lr_scheduler_config.epochs + elif isinstance(lr_scheduler_config, CyclicLRConfig): + scheduler_class = CyclicLR + scheduler_params = lr_scheduler_config.model_dump( + exclude={ + "scheduler_name", + "frequency", + "interval", + "monitor", + "epochs_up", + "epochs_down", + } + ) + + # Calculate steps_per_epoch + steps_per_epoch = self.number_of_training_batches + + # Calculate step_size_up and step_size_down + epochs_up = lr_scheduler_config.epochs_up + epochs_down = ( + lr_scheduler_config.epochs_down or epochs_up + ) # Symmetric cycle if not specified + step_size_up = int(epochs_up * steps_per_epoch) + step_size_down = int(epochs_down * steps_per_epoch) + + scheduler_params["step_size_up"] = step_size_up + scheduler_params["step_size_down"] = step_size_down else: - raise RuntimeError(f"Unrecognized mode: {log_mode}") + raise NotImplementedError( + f"Unsupported learning rate scheduler: {lr_scheduler_config.scheduler_name}" + ) + lr_scheduler_instance = scheduler_class(optimizer, **scheduler_params) + + scheduler = { + "scheduler": lr_scheduler_instance, + "monitor": monitor, # Name of the metric to monitor + "interval": interval, + "frequency": frequency, + } + + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +from openff.units import unit + + +class PotentialTrainer: + """ + Class for training neural network potentials using PyTorch Lightning. + """ + + def __init__( + self, + *, + dataset_parameter: DatasetParameters, + potential_parameter: T_NNP_Parameters, + training_parameter: TrainingParameters, + runtime_parameter: RuntimeParameters, + dataset_statistic: Dict[str, Dict[str, unit.Quantity]], + use_default_dataset_statistic: bool, + optimizer_class: Type[Optimizer] = torch.optim.AdamW, + potential_seed: Optional[int] = None, + verbose: bool = False, + ): + """ + Initializes the TrainingAdapter with the specified model and training + configuration. + + Parameters + ---------- + dataset_parameter : DatasetParameters + Parameters for the dataset. + potential_parameter : Union[ANI2xParameters, SAKEParameters, + SchNetParameters, PhysNetParameters, PaiNNParameters, + TensorNetParameters] + Parameters for the potential model. + training_parameter : TrainingParameters + Parameters for the training process. + runtime_parameter : RuntimeParameters + Parameters for runtime configuration. + dataset_statistic : Dict[str, Dict[str, unit.Quantity]] + Dataset statistics such as mean and standard deviation. + use_default_dataset_statistic: bool + Whether to use default dataset statistic + optimizer_class : Type[Optimizer], optional + The optimizer class to use for training, by default + torch.optim.AdamW. + potential_seed: Optional[int], optional + Seed to initialize the potential training adapter, default is None. + verbose : bool, optional + If True, enables verbose logging, by default False. + """ - for phase, error_dict in errors: - # skip if log_on_training_step is not requested - if phase == "train" and not self.log_on_training_step: - continue + super().__init__() - metrics = {} - for property, metrics_dict in error_dict.items(): - for name, metric in metrics_dict.items(): - name = f"{phase}/{property}/{conv[name]}" - metrics[name] = metric.compute() - metric.reset() - # log dict, print val metrics to console - self.log_dict(metrics, on_epoch=True, prog_bar=(phase == "val")) + # Assign parameters to instance variables + self.dataset_parameter = dataset_parameter + self.potential_parameter = potential_parameter + self.training_parameter = training_parameter + self.runtime_parameter = runtime_parameter + self.verbose = verbose + + # Setup data module + self.datamodule = self.setup_datamodule() + # Read and assign provided dataset statistics + self.dataset_statistic = ( + self.read_dataset_statistics() + if not use_default_dataset_statistic + else dataset_statistic + ) + self.experiment_logger = self.setup_logger() + self.callbacks = self.setup_callbacks() + self.trainer = self.setup_trainer() + self.optimizer_class = optimizer_class + self.lightning_module = self.setup_lightning_module(potential_seed) - def configure_optimizers(self) -> Dict[str, Any]: + def read_dataset_statistics( + self, + ) -> dict[str, dict[str, Any]]: """ - Configures the model's optimizers (and optionally schedulers). + Read and log dataset statistics. Returns ------- - Dict[str, Any] - A dictionary containing the optimizer and optionally the learning rate scheduler - to be used within the PyTorch Lightning training process. - """ - - optimizer = self.optimizer(self.model.parameters(), lr=self.learning_rate) - - lr_scheduler_config = self.lr_scheduler_config - lr_scheduler = ReduceLROnPlateau( - optimizer, - mode=lr_scheduler_config["mode"], - factor=lr_scheduler_config["factor"], - patience=lr_scheduler_config["patience"], - cooldown=lr_scheduler_config["cooldown"], - min_lr=lr_scheduler_config["min_lr"], - threshold=lr_scheduler_config["threshold"], - threshold_mode="abs", + Dict[str, float] + The dataset statistics. + """ + from modelforge.potential.utils import ( + convert_str_to_unit_in_dataset_statistics, + read_dataset_statistics, ) - lr_scheduler_config = { - "scheduler": lr_scheduler, - "monitor": lr_scheduler_config["monitor"], # Name of the metric to monitor - "interval": "epoch", - "frequency": 1, - } - return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} + # read toml file + dataset_statistic = read_dataset_statistics( + self.datamodule.dataset_statistic_filename + ) + # convert dictionary of str:str to str:units + dataset_statistic = convert_str_to_unit_in_dataset_statistics(dataset_statistic) + log.info( + f"Setting per_atom_energy_mean and per_atom_energy_stddev for {self.potential_parameter.potential_name}" + ) + log.info( + f"per_atom_energy_mean: {dataset_statistic['training_dataset_statistics']['per_atom_energy_mean']}" + ) + log.info( + f"per_atom_energy_stddev: {dataset_statistic['training_dataset_statistics']['per_atom_energy_stddev']}" + ) + return dataset_statistic + def setup_datamodule(self) -> DataModule: + """ + Set up the DataModule for the dataset. -def return_toml_config( - config_path: Optional[str] = None, - potential_path: Optional[str] = None, - dataset_path: Optional[str] = None, - training_path: Optional[str] = None, - runtime_path: Optional[str] = None, -): - """ - Read one or more TOML configuration files and return the parsed configuration. + Returns + ------- + DataModule + Configured DataModule instance. + """ + from modelforge.dataset.dataset import DataModule + from modelforge.dataset.utils import REGISTERED_SPLITTING_STRATEGIES + + dm = DataModule( + name=self.dataset_parameter.dataset_name, + batch_size=self.training_parameter.batch_size, + remove_self_energies=self.training_parameter.remove_self_energies, + shift_center_of_mass_to_origin=self.training_parameter.shift_center_of_mass_to_origin, + version_select=self.dataset_parameter.version_select, + local_cache_dir=self.runtime_parameter.local_cache_dir, + splitting_strategy=REGISTERED_SPLITTING_STRATEGIES[ + self.training_parameter.splitting_strategy.name + ]( + seed=self.training_parameter.splitting_strategy.seed, + split=self.training_parameter.splitting_strategy.data_split, + ), + regenerate_processed_cache=self.dataset_parameter.regenerate_processed_cache, + ) + dm.prepare_data() + dm.setup() + return dm - Parameters - ---------- - config_path : str, optional - The path to the TOML configuration file. - potential_path : str, optional - The path to the TOML file defining the potential configuration. - dataset_path : str, optional - The path to the TOML file defining the dataset configuration. - training_path : str, optional - The path to the TOML file defining the training configuration. - runtime_path : str, optional - The path to the TOML file defining the runtime configuration. + def setup_lightning_module( + self, potential_seed: Optional[int] = None + ) -> pL.LightningModule: + """ + Set up the model for training. - Returns - ------- - dict - The merged parsed configuration from the TOML files. - """ - import toml + Parameters + ---------- + potential_seed : int, optional + Seed to be used to initialize the potential, by default None. - config = {} + Returns + ------- + nn.Module + Configured model instance, wrapped in a TrainingAdapter. + """ - if config_path: - config = toml.load(config_path) - log.info(f"Reading config from : {config_path}") - else: - if potential_path: - config["potential"] = toml.load(potential_path)["potential"] - log.info(f"Reading potential config from : {potential_path}") - if dataset_path: - config["dataset"] = toml.load(dataset_path)["dataset"] - log.info(f"Reading dataset config from : {dataset_path}") - if training_path: - config["training"] = toml.load(training_path)["training"] - log.info(f"Reading training config from : {training_path}") - if runtime_path: - config["runtime"] = toml.load(runtime_path)["runtime"] - log.info(f"Reading runtime config from : {runtime_path}") - return config + # Initialize model + return TrainingAdapter( + potential_parameter=self.potential_parameter, + dataset_statistic=self.dataset_statistic, + training_parameter=self.training_parameter, + optimizer_class=self.optimizer_class, + nr_of_training_batches=len(self.datamodule.train_dataloader()), + potential_seed=potential_seed, + ) + def setup_logger(self) -> pL.loggers.Logger: + """ + Set up the experiment logger based on the configuration. -from typing import List, Optional, Union + Returns + ------- + pL.loggers.Logger + Configured logger instance. + """ + experiment_name = self._format_experiment_name( + self.runtime_parameter.experiment_name + ) -def read_config_and_train( - config_path: Optional[str] = None, - potential_path: Optional[str] = None, - dataset_path: Optional[str] = None, - training_path: Optional[str] = None, - runtime_path: Optional[str] = None, - accelerator: Optional[str] = None, - device: Optional[Union[int, List[int]]] = None, -): - """ - Reads one or more TOML configuration files and performs training based on the parameters. + if self.training_parameter.experiment_logger.logger_name == "tensorboard": + from lightning.pytorch.loggers import TensorBoardLogger - Parameters - ---------- - config_path : str, optional - Path to the TOML configuration file. - potential_path : str, optional - Path to the TOML file defining the potential configuration. - dataset_path : str, optional - Path to the TOML file defining the dataset configuration. - training_path : str, optional - Path to the TOML file defining the training configuration. - runtime_path : str, optional - Path to the TOML file defining the runtime configuration. - accelerator : str, optional - Accelerator type to use for training. - device : int|List[int], optional - Device index to use for training. - """ - # Read the TOML file - config = return_toml_config( - config_path, potential_path, dataset_path, training_path, runtime_path - ) + logger = ( + TensorBoardLogger( + save_dir=str( + self.training_parameter.experiment_logger.tensorboard_configuration.save_dir + ), + name=experiment_name, + ), + ) + elif self.training_parameter.experiment_logger.logger_name == "wandb": + from modelforge.utils.io import check_import + + check_import("wandb") + from lightning.pytorch.loggers import WandbLogger + + logger = WandbLogger( + save_dir=str( + self.training_parameter.experiment_logger.wandb_configuration.save_dir + ), + log_model=self.training_parameter.experiment_logger.wandb_configuration.log_model, + project=self.training_parameter.experiment_logger.wandb_configuration.project, + group=self.training_parameter.experiment_logger.wandb_configuration.group, + job_type=self.training_parameter.experiment_logger.wandb_configuration.job_type, + tags=self._generate_tags( + self.training_parameter.experiment_logger.wandb_configuration.tags + ), + notes=self.training_parameter.experiment_logger.wandb_configuration.notes, + name=experiment_name, + ) + else: + raise ValueError("Unsupported logger type.") + return logger - # Extract parameters - potential_config = config["potential"] - dataset_config = config["dataset"] - training_config = config["training"] - runtime_config = config["runtime"] - - # Override config parameters with command-line arguments if provided - if accelerator: - runtime_config["accelerator"] = accelerator - if device is not None: - runtime_config["devices"] = device - - log.debug(f"Potential config: {potential_config}") - log.debug(f"Dataset config: {dataset_config}") - log.debug(f"Training config: {training_config}") - log.debug(f"Runtime config: {runtime_config}") - - # Call the perform_training function with extracted parameters - perform_training( - potential_config=potential_config, - training_config=training_config, - dataset_config=dataset_config, - runtime_config=runtime_config, - ) + def setup_callbacks(self) -> List[Any]: + """ + Set up the callbacks for the trainer. + The callbacks include early stopping (optional), model checkpointing, and stochastic weight averaging (optional). -from lightning import Trainer + Returns + ------- + List[Any] + List of configured callbacks. + """ + from lightning.pytorch.callbacks import ( + EarlyStopping, + ModelCheckpoint, + StochasticWeightAveraging, + Callback, + ) + callbacks = [] + if self.training_parameter.stochastic_weight_averaging: + callbacks.append( + StochasticWeightAveraging( + **self.training_parameter.stochastic_weight_averaging.model_dump() + ) + ) -def log_training_arguments( - potential_config: Dict[str, Any], - training_config: Dict[str, Any], - dataset_config: Dict[str, Any], - runtime_config: Dict[str, Any], -): - """ - Log arguments that are passed to the training routine. - - Arguments - ---- - potential_config: Dict[str, Any] - config for the potential model - training_config: Dict[str, Any] - config for the training process - dataset_config: Dict[str, Any] - config for the dataset - runtime_config: Dict[str, Any] - config for the runtime - """ + if self.training_parameter.early_stopping: + callbacks.append( + EarlyStopping(**self.training_parameter.early_stopping.model_dump()) + ) - save_dir = runtime_config["save_dir"] - log.info(f"Saving logs to location: {save_dir}") + # Save the best model based on the validation loss + # NOTE: The filename is formatted as "best_{potential_name}-{dataset_name}-{epoch:02d}" + checkpoint_filename = ( + f"best_{self.potential_parameter.potential_name}-{self.dataset_parameter.dataset_name}" + + "-{epoch:02d}" + ) + callbacks.append( + ModelCheckpoint( + save_top_k=2, + monitor=self.training_parameter.monitor, + filename=checkpoint_filename, + ) + ) - experiment_name = runtime_config["experiment_name"] - log.info(f"Saving logs in dir: {experiment_name}") + # compute gradient norm + class GradNormCallback(Callback): + """ + Logs the gradient norm. + """ - version_select = dataset_config.get("version_select", "latest") - if version_select == "latest": - log.info(f"Using default dataset version: {version_select}") - else: - log.info(f"Using dataset version: {version_select}") + def on_before_optimizer_step(self, trainer, pl_module, optimizer): + pl_module.log("grad_norm/model", gradient_norm(pl_module)) - local_cache_dir = runtime_config.get("local_cache_dir", "./") - if local_cache_dir is None: - log.info(f"Using default cache directory: {local_cache_dir}") - else: - log.info(f"Using cache directory: {local_cache_dir}") + if self.training_parameter.log_norm: + callbacks.append(GradNormCallback()) - accelerator = runtime_config.get("accelerator", "cpu") - if accelerator == "cpu": - log.info(f"Using default accelerator: {accelerator}") - else: - log.info(f"Using accelerator: {accelerator}") - nr_of_epochs = training_config.get("nr_of_epochs", 10) - if nr_of_epochs == 10: - log.info(f"Using default number of epochs: {nr_of_epochs}") - else: - log.info(f"Training for {nr_of_epochs} epochs") - num_nodes = runtime_config.get("num_nodes", 1) - if num_nodes == 1: - log.info(f"Using default number of nodes: {num_nodes}") - else: - log.info(f"Training on {num_nodes} nodes") - devices = runtime_config.get("devices", 1) - if devices == 1: - log.info(f"Using default device index/number: {devices}") - else: - log.info(f"Using device index/number: {devices}") + return callbacks - batch_size = training_config.get("batch_size", 128) - if batch_size == 128: - log.info(f"Using default batch size: {batch_size}") - else: - log.info(f"Using batch size: {batch_size}") + def setup_trainer(self) -> Trainer: + """ + Set up the Trainer for training. - remove_self_energies = training_config.get("remove_self_energies", False) - if remove_self_energies is False: - log.warning( - f"Using default for removing self energies: Self energies are not removed" + Returns + ------- + Trainer + Configured Trainer instance. + """ + from lightning import Trainer + + # if devices is a list (but longer than 1) + if ( + isinstance(self.runtime_parameter.devices, list) + and len(self.runtime_parameter.devices) > 1 + ) or ( + isinstance(self.runtime_parameter.devices, int) + and self.runtime_parameter.devices > 1 + ): + from lightning.pytorch.strategies import DDPStrategy + + strategy = DDPStrategy(find_unused_parameters=True) + else: + strategy = "auto" + + trainer = Trainer( + strategy=strategy, + max_epochs=self.training_parameter.number_of_epochs, + min_epochs=self.training_parameter.min_number_of_epochs, + num_nodes=self.runtime_parameter.number_of_nodes, + devices=self.runtime_parameter.devices, + accelerator=self.runtime_parameter.accelerator, + logger=self.experiment_logger, + callbacks=self.callbacks, + benchmark=True, + inference_mode=False, + limit_train_batches=self.training_parameter.limit_train_batches, + limit_val_batches=self.training_parameter.limit_val_batches, + limit_test_batches=self.training_parameter.limit_test_batches, + num_sanity_val_steps=1, + gradient_clip_val=5.0, # FIXME: hardcoded for now + log_every_n_steps=self.runtime_parameter.log_every_n_steps, + enable_model_summary=True, + enable_progress_bar=self.runtime_parameter.verbose, # if true will show progress bar ) - else: - log.info(f"Removing self energies: {remove_self_energies}") + return trainer - splitting_strategy = training_config["splitting_strategy"]["name"] - data_split = training_config["splitting_strategy"]["data_split"] - log.info(f"Using splitting strategy: {splitting_strategy} with split: {data_split}") + def train_potential(self) -> Trainer: + """ + Run the training, validation, and testing processes. - early_stopping_config = training_config.get("early_stopping", None) - if early_stopping_config is None: - log.info(f"Using default: No early stopping performed") + Returns + ------- + Trainer + The configured trainer instance after running the training process. + """ + self.trainer.fit( + self.lightning_module, + train_dataloaders=self.datamodule.train_dataloader( + num_workers=self.dataset_parameter.num_workers, + pin_memory=self.dataset_parameter.pin_memory, + ), + val_dataloaders=self.datamodule.val_dataloader(), + ckpt_path=( + self.runtime_parameter.checkpoint_path + if self.runtime_parameter.checkpoint_path != "None" + else None + ), # NOTE: automatically resumes training from checkpoint + ) - stochastic_weight_averaging_config = training_config.get( - "stochastic_weight_averaging_config", None - ) + self.trainer.validate( + model=self.lightning_module, + dataloaders=self.datamodule.val_dataloader(), + ckpt_path="best", + verbose=True, + ) - num_workers = dataset_config.get("number_of_worker", 4) - if num_workers == 4: - log.info( - f"Using default number of workers for training data loader: {num_workers}" + self.trainer.test( + model=self.lightning_module, + dataloaders=self.datamodule.test_dataloader(), + ckpt_path="best", + verbose=True, ) - else: - log.info(f"Using {num_workers} workers for training data loader") + return self.trainer - pin_memory = dataset_config.get("pin_memory", False) - if pin_memory is False: - log.info(f"Using default value for pinned_memory: {pin_memory}") - else: - log.info(f"Using pinned_memory: {pin_memory}") - - model_name = potential_config["model_name"] - dataset_name = dataset_config["dataset_name"] - log.info(training_config["training_parameter"]["loss_parameter"]) - log.debug( - f""" -Training {model_name} on {dataset_name}-{version_select} dataset with {accelerator} -accelerator on {num_nodes} nodes for {nr_of_epochs} epochs. -Experiments are saved to: {save_dir}/{experiment_name}. -Local cache directory: {local_cache_dir} -""" - ) + def config_prior(self): + """ + Configures model-specific priors if the model implements them. + """ + if hasattr(self.lightning_module, "_config_prior"): + return self.lightning_module._config_prior() + + log.warning("Model does not implement _config_prior().") + raise NotImplementedError() + + def _format_experiment_name(self, experiment_name: str) -> str: + """ + Replace the placeholders in the experiment name with the actual values. + + Parameters + ---------- + experiment_name : str + The experiment name with placeholders. + + Returns + ------- + str + The experiment name with the placeholders replaced. + """ + # replace placeholders in the experiment name + experiment_name = experiment_name.replace( + "{potential_name}", self.potential_parameter.potential_name + ) + experiment_name = experiment_name.replace( + "{dataset_name}", self.dataset_parameter.dataset_name + ) + return experiment_name + + def _generate_tags(self, tags: List[str]) -> List[str]: + """Generates tags for the experiment.""" + import modelforge + + tags.extend( + [ + str(modelforge.__version__), + self.dataset_parameter.dataset_name, + self.potential_parameter.potential_name, + f"loss-{'-'.join(self.training_parameter.loss_parameter.loss_components)}", + ] + ) + return tags + + +from typing import List, Optional, Union -def perform_training( - potential_config: Dict[str, Any], - training_config: Dict[str, Any], - dataset_config: Dict[str, Any], - runtime_config: Dict[str, Any], +def read_config( + condensed_config_path: Optional[str] = None, + training_parameter_path: Optional[str] = None, + dataset_parameter_path: Optional[str] = None, + potential_parameter_path: Optional[str] = None, + runtime_parameter_path: Optional[str] = None, + accelerator: Optional[str] = None, + devices: Optional[Union[int, List[int]]] = None, + number_of_nodes: Optional[int] = None, + experiment_name: Optional[str] = None, + save_dir: Optional[str] = None, + local_cache_dir: Optional[str] = None, checkpoint_path: Optional[str] = None, -) -> Trainer: + log_every_n_steps: Optional[int] = None, + simulation_environment: Optional[str] = None, +): """ - Performs the training process for a neural network potential model. + Reads one or more TOML configuration files and loads them into the pydantic + models. Parameters ---------- - potential_config : Dict[str, Any], optional - Additional parameters for the potential model. - training_config : Dict[str, Any], optional - Additional parameters for the training process. - dataset_config : Dict[str, Any], optional - Additional parameters for the dataset. + condensed_config_path : Optional[str], optional + Path to the TOML configuration that contains all parameters for the + dataset, potential, training, and runtime parameters. Any other provided + configuration files will be ignored. + training_parameter_path : Optional[str], optional + Path to the TOML file defining the training parameters. + dataset_parameter_path : Optional[str], optional + Path to the TOML file defining the dataset parameters. + potential_parameter_path : Optional[str], optional + Path to the TOML file defining the potential parameters. + runtime_parameter_path : Optional[str], optional + Path to the TOML file defining the runtime parameters. If this is not + provided, the code will attempt to use the runtime parameters provided + as arguments. + accelerator : Optional[str], optional + Accelerator type to use. If provided, this overrides the accelerator + type in the runtime_defaults configuration. + devices : Optional[Union[int, List[int]]], optional + Device index/indices to use. If provided, this overrides the devices in + the runtime_defaults configuration. + number_of_nodes : Optional[int], optional + Number of nodes to use. If provided, this overrides the number of nodes + in the runtime_defaults configuration. + experiment_name : Optional[str], optional + Name of the experiment. If provided, this overrides the experiment name + in the runtime_defaults configuration. + save_dir : Optional[str], optional + Directory to save the model. If provided, this overrides the save + directory in the runtime_defaults configuration. + local_cache_dir : Optional[str], optional + Local cache directory. If provided, this overrides the local cache + directory in the runtime_defaults configuration. + checkpoint_path : Optional[str], optional + Path to the checkpoint file. If provided, this overrides the checkpoint + path in the runtime_defaults configuration. + log_every_n_steps : Optional[int], optional + Number of steps to log. If provided, this overrides the + log_every_n_steps in the runtime_defaults configuration. + simulation_environment : Optional[str], optional + Simulation environment. If provided, this overrides the simulation + environment in the runtime_defaults configuration. Returns ------- - Trainer + Tuple[TrainingParameters, DatasetParameters, T_NNP_Parameters, + RuntimeParameters] + Tuple containing the training, dataset, potential, and runtime + parameters. """ + import toml - from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger - - from modelforge.dataset.utils import REGISTERED_SPLITTING_STRATEGIES - from lightning import Trainer - from modelforge.potential import NeuralNetworkPotentialFactory - from modelforge.dataset.dataset import DataModule - - # NOTE --------------------------------------- NOTE # - # FIXME TODO: move this to a dataclass and control default - # behavior from there this current approach is hacky and error prone - save_dir = runtime_config["save_dir"] - log.info(f"Saving logs to location: {save_dir}") - - experiment_name = runtime_config["experiment_name"] - if experiment_name == "{model_name}_{dataset_name}": - experiment_name = ( - f"{potential_config['model_name']}_{dataset_config['dataset_name']}" - ) - - model_name = potential_config["model_name"] - dataset_name = dataset_config["dataset_name"] + # Initialize the config dictionaries + training_config_dict = {} + dataset_config_dict = {} + potential_config_dict = {} + runtime_config_dict = {} - log_training_arguments( - potential_config, training_config, dataset_config, runtime_config - ) + if condensed_config_path is not None: + # Load all configurations from a single condensed TOML file + config = toml.load(condensed_config_path) + log.info(f"Reading config from : {condensed_config_path}") - version_select = dataset_config.get("version_select", "latest") - accelerator = runtime_config.get("accelerator", "cpu") - splitting_strategy = training_config["splitting_strategy"] - nr_of_epochs = runtime_config.get("nr_of_epochs", 10) - num_nodes = runtime_config.get("num_nodes", 1) - devices = runtime_config.get("devices", 1) - batch_size = training_config.get("batch_size", 128) - remove_self_energies = training_config.get("remove_self_energies", False) - early_stopping_config = training_config.get("early_stopping", None) - stochastic_weight_averaging_config = training_config.get( - "stochastic_weight_averaging_config", None - ) - num_workers = dataset_config.get("number_of_worker", 4) - pin_memory = dataset_config.get("pin_memory", False) - local_cache_dir = runtime_config.get("local_cache_dir", "./") - # NOTE --------------------------------------- NOTE # - # FIXME TODO: move this to a dataclass and control default - # behavior from there this current approach is hacky and error prone - - # set up tensor board logger - if training_config["experiment_logger"]["logger_name"].lower() == "tensorboard": - logger = TensorBoardLogger(save_dir, name=experiment_name) - elif training_config["experiment_logger"]["logger_name"].lower() == "wandb": - logger = WandbLogger(save_dir=save_dir, log_model=True, name=experiment_name) + training_config_dict = config.get("training", {}) + dataset_config_dict = config.get("dataset", {}) + potential_config_dict = config.get("potential", {}) + runtime_config_dict = config.get("runtime", {}) else: - raise ValueError(f"Unknown logger name: {training_config['logger_name']}") - # Set up dataset - dm = DataModule( - name=dataset_name, - batch_size=batch_size, - remove_self_energies=remove_self_energies, - version_select=version_select, - local_cache_dir=local_cache_dir, - splitting_strategy=REGISTERED_SPLITTING_STRATEGIES[splitting_strategy["name"]]( - seed=splitting_strategy.get("splitting_seed", 42), - split=splitting_strategy["data_split"], - ), + if training_parameter_path: + training_config_dict = toml.load(training_parameter_path).get( + "training", {} + ) + if dataset_parameter_path: + dataset_config_dict = toml.load(dataset_parameter_path).get("dataset", {}) + if potential_parameter_path: + potential_config_dict = toml.load(potential_parameter_path).get( + "potential", {} + ) + if runtime_parameter_path: + runtime_config_dict = toml.load(runtime_parameter_path).get("runtime", {}) + + # Override runtime configuration with command-line arguments if provided + runtime_overrides = { + "accelerator": accelerator, + "devices": devices, + "number_of_nodes": number_of_nodes, + "experiment_name": experiment_name, + "save_dir": save_dir, + "local_cache_dir": local_cache_dir, + "checkpoint_path": checkpoint_path, + "log_every_n_steps": log_every_n_steps, + "simulation_environment": simulation_environment, + } + + for key, value in runtime_overrides.items(): + if value is not None: + runtime_config_dict[key] = value + + # Load and instantiate the data classes with the merged configuration + from modelforge.dataset.dataset import DatasetParameters + from modelforge.potential import _Implemented_NNP_Parameters + from modelforge.train.parameters import RuntimeParameters, TrainingParameters + + potential_name = potential_config_dict["potential_name"] + PotentialParameters = ( + _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name) ) - dm.prepare_data() - dm.setup() - # read dataset statistics - import toml + dataset_parameters = DatasetParameters(**dataset_config_dict) + training_parameters = TrainingParameters(**training_config_dict) + runtime_parameters = RuntimeParameters(**runtime_config_dict) + potential_parameter = PotentialParameters(**potential_config_dict) - dataset_statistic = toml.load(dm.dataset_statistic_filename) - log.info( - f"Setting per_atom_energy_mean and per_atom_energy_stddev for {model_name}" - ) - log.info( - f"per_atom_energy_mean: {dataset_statistic['training_dataset_statistics']['per_atom_energy_mean']}" - ) - log.info( - f"per_atom_energy_stddev: {dataset_statistic['training_dataset_statistics']['per_atom_energy_stddev']}" + return ( + training_parameters, + dataset_parameters, + potential_parameter, + runtime_parameters, ) - # Set up model - model = NeuralNetworkPotentialFactory.generate_model( - use="training", - dataset_statistic=dataset_statistic, - model_parameter=potential_config, - training_parameter=training_config["training_parameter"], - ) - # set up traininer - from lightning.pytorch.callbacks.early_stopping import EarlyStopping - from lightning.pytorch.callbacks.stochastic_weight_avg import ( - StochasticWeightAveraging, - ) - - # set up callbacks - callbacks = [] - if stochastic_weight_averaging_config: - callbacks.append( - StochasticWeightAveraging(**stochastic_weight_averaging_config) - ) - if early_stopping_config: - log.warning("No early stopping is defined. Do you have resources to waste?") - callbacks.append(EarlyStopping(**early_stopping_config)) - - from lightning.pytorch.callbacks import ModelCheckpoint - - checkpoint_callback = ModelCheckpoint( - save_top_k=2, - monitor="val/per_molecule_energy/rmse", - filename="best_model", - ) +def read_config_and_train( + condensed_config_path: Optional[str] = None, + training_parameter_path: Optional[str] = None, + dataset_parameter_path: Optional[str] = None, + potential_parameter_path: Optional[str] = None, + runtime_parameter_path: Optional[str] = None, + accelerator: Optional[str] = None, + devices: Optional[Union[int, List[int]]] = None, + number_of_nodes: Optional[int] = None, + experiment_name: Optional[str] = None, + save_dir: Optional[str] = None, + local_cache_dir: Optional[str] = None, + checkpoint_path: Optional[str] = None, + log_every_n_steps: Optional[int] = None, + simulation_environment: Optional[str] = "PyTorch", +): + """ + Reads one or more TOML configuration files and performs training based on the parameters. - callbacks.append(checkpoint_callback) + Parameters + ---------- + condensed_config_path : str, optional + Path to the TOML configuration that contains all parameters for the dataset, potential, training, and runtime parameters. + Any other provided configuration files will be ignored. + training_parameter_path : str, optional + Path to the TOML file defining the training parameters. + dataset_parameter_path : str, optional + Path to the TOML file defining the dataset parameters. + potential_parameter_path : str, optional + Path to the TOML file defining the potential parameters. + runtime_parameter_path : str, optional + Path to the TOML file defining the runtime parameters. If this is not provided, the code will attempt to use + the runtime parameters provided as arguments. + accelerator : str, optional + Accelerator type to use. If provided, this overrides the accelerator type in the runtime_defaults configuration. + devices : int|List[int], optional + Device index/indices to use. If provided, this overrides the devices in the runtime_defaults configuration. + number_of_nodes : int, optional + Number of nodes to use. If provided, this overrides the number of nodes in the runtime_defaults configuration. + experiment_name : str, optional + Name of the experiment. If provided, this overrides the experiment name in the runtime_defaults configuration. + save_dir : str, optional + Directory to save the model. If provided, this overrides the save directory in the runtime_defaults configuration. + local_cache_dir : str, optional + Local cache directory. If provided, this overrides the local cache directory in the runtime_defaults configuration. + checkpoint_path : str, optional + Path to the checkpoint file. If provided, this overrides the checkpoint path in the runtime_defaults configuration. + log_every_n_steps : int, optional + Number of steps to log. If provided, this overrides the log_every_n_steps in the runtime_defaults configuration. + simulation_environment : str, optional + Simulation environment. If provided, this overrides the simulation environment in the runtime_defaults configuration. - # set up trainer - trainer = Trainer( - max_epochs=nr_of_epochs, - num_nodes=num_nodes, - devices=devices, + Returns + ------- + Trainer + The configured trainer instance after running the training process. + """ + from modelforge.potential.potential import NeuralNetworkPotentialFactory + + ( + training_parameter, + dataset_parameter, + potential_parameter, + runtime_parameter, + ) = read_config( + condensed_config_path=condensed_config_path, + training_parameter_path=training_parameter_path, + dataset_parameter_path=dataset_parameter_path, + potential_parameter_path=potential_parameter_path, + runtime_parameter_path=runtime_parameter_path, accelerator=accelerator, - logger=logger, # Add the logger here - callbacks=callbacks, - inference_mode=False, - num_sanity_val_steps=2, - log_every_n_steps=50, + devices=devices, + number_of_nodes=number_of_nodes, + experiment_name=experiment_name, + save_dir=save_dir, + local_cache_dir=local_cache_dir, + checkpoint_path=checkpoint_path, + log_every_n_steps=log_every_n_steps, + simulation_environment=simulation_environment, ) - # Run training loop and validate - trainer.fit( - model, - train_dataloaders=dm.train_dataloader( - num_workers=num_workers, pin_memory=pin_memory - ), - val_dataloaders=dm.val_dataloader(), - ckpt_path=checkpoint_path, + trainer = NeuralNetworkPotentialFactory.generate_trainer( + potential_parameter=potential_parameter, + training_parameter=training_parameter, + dataset_parameter=dataset_parameter, + runtime_parameter=runtime_parameter, ) - trainer.validate( - model=model, dataloaders=dm.val_dataloader(), ckpt_path="best", verbose=True - ) - trainer.test(dataloaders=dm.test_dataloader(), ckpt_path="best", verbose=True) - return trainer \ No newline at end of file + return trainer.train_potential() diff --git a/modelforge/train/tuning.py b/modelforge/train/tuning.py index 13c08387..bef07ba7 100644 --- a/modelforge/train/tuning.py +++ b/modelforge/train/tuning.py @@ -1,13 +1,14 @@ +""" +This module contains functions and classes for hyperparameter tuning and distributed training using Ray Tune. +""" + import torch -from modelforge.utils.io import import_ +from ray import air, tune -air = import_("ray").air -tune = import_("ray").tune -# from ray import air, tune -ASHAScheduler = import_("ray").tune.scheduleres.ASHAScheduler -# from ray.tune.schedulers import ASHAScheduler +from typing import Type +from ray.tune.schedulers import ASHAScheduler def tune_model( @@ -68,7 +69,16 @@ def objective(): class RayTuner: - def __init__(self, model) -> None: + + def __init__(self, model: Type[torch.nn.Module]) -> None: + """ + Initializes the RayTuner with the given model. + + Parameters + ---------- + model : torch.nn.Module + The model to be tuned and trained using Ray. + """ self.model = model def train_func(self): @@ -90,6 +100,7 @@ def train_func(self): ) import lightning as pl + # Configure PyTorch Lightning trainer with Ray DDP strategy trainer = pl.Trainer( devices="auto", accelerator="auto", @@ -99,6 +110,7 @@ def train_func(self): enable_progress_bar=False, ) trainer = prepare_trainer(trainer) + # Fit the model using the trainer trainer.fit(self.model, self.train_dataloader, self.val_dataloader) def get_ray_trainer(self, number_of_workers: int = 2, use_gpu: bool = False): @@ -117,18 +129,21 @@ def get_ray_trainer(self, number_of_workers: int = 2, use_gpu: bool = False): Returns ------- - Ray Trainer + TorchTrainer The configured Ray Trainer for distributed training. """ from ray.train import CheckpointConfig, RunConfig, ScalingConfig + from ray.train.torch import TorchTrainer + # Configure scaling for Ray Trainer scaling_config = ScalingConfig( num_workers=number_of_workers, use_gpu=use_gpu, resources_per_worker={"CPU": 1, "GPU": 1} if use_gpu else {"CPU": 1}, ) + # Configure run settings for Ray Trainer run_config = RunConfig( checkpoint_config=CheckpointConfig( num_to_keep=2, @@ -136,9 +151,7 @@ def get_ray_trainer(self, number_of_workers: int = 2, use_gpu: bool = False): checkpoint_score_order="min", ), ) - from ray.train.torch import TorchTrainer - - # Define a TorchTrainer without hyper-parameters for Tuner + # Define and return the TorchTrainer ray_trainer = TorchTrainer( self.train_func, scaling_config=scaling_config, @@ -155,6 +168,7 @@ def tune_with_ray( number_of_samples: int = 10, number_of_ray_workers: int = 2, train_on_gpu: bool = False, + metric: str = "val/per_system_energy/rmse", ): """ Performs hyperparameter tuning using Ray Tune. @@ -174,39 +188,43 @@ def tune_with_ray( The number of samples (trial runs) to perform, by default 10. number_of_ray_workers : int, optional The number of Ray workers to use for distributed training, by default 2. - use_gpu : bool, optional + train_on_gpu : bool, optional Whether to use GPUs for training, by default False. + metric : str, optional + The metric to use for evaluation and early stopping, by default "val/per_system_energy/rmse Returns ------- - Tune experiment analysis object - The result of the hyperparameter tuning session, containing performance metrics and the best hyperparameters found. + ExperimentAnalysis + The result of the hyperparameter tuning session, containing performance metrics + and the best hyperparameters found. """ - from modelforge.utils.io import import_ - tune = import_("ray").tune - # from ray import tune + from ray import tune - ASHAScheduler = import_("ray").tune.schedulers.ASHAScheduler - # from ray.tune.schedulers import ASHAScheduler + from ray.tune.schedulers import ASHAScheduler self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader + # Initialize Ray Trainer ray_trainer = self.get_ray_trainer( number_of_workers=number_of_ray_workers, use_gpu=train_on_gpu ) + # Configure ASHA scheduler for early stopping scheduler = ASHAScheduler( max_t=number_of_epochs, grace_period=1, reduction_factor=2 ) + # Define tuning configuration tune_config = tune.TuneConfig( - metric="val/energy/rmse", + metric=metric, mode="min", scheduler=scheduler, num_samples=number_of_samples, ) + # Initialize and run the tuner tuner = tune.Tuner( ray_trainer, param_space={"train_loop_config": self.model.config_prior()}, diff --git a/modelforge/train/utils.py b/modelforge/train/utils.py index 80468b48..133fd4b6 100644 --- a/modelforge/train/utils.py +++ b/modelforge/train/utils.py @@ -1,8 +1,6 @@ def shared_config_prior(): - from modelforge.utils.io import import_ - tune = import_("ray").tune - # from ray import tune + from ray import tune return { "lr": tune.loguniform(1e-5, 1e-1), diff --git a/modelforge/utils/__init__.py b/modelforge/utils/__init__.py index d97f639b..605aea59 100644 --- a/modelforge/utils/__init__.py +++ b/modelforge/utils/__init__.py @@ -1,3 +1,4 @@ -"""modelforge utilities.""" +"""Module of general modelforge utilities.""" from .prop import SpeciesEnergies, PropertyNames +from .misc import lock_with_attribute diff --git a/modelforge/utils/io.py b/modelforge/utils/io.py index 9dd5f4d5..8f0eab42 100644 --- a/modelforge/utils/io.py +++ b/modelforge/utils/io.py @@ -1,4 +1,4 @@ -""" Module for handling importing of external libraries. +""" Module for handling importing external libraries that are optional dependencies. The general approach here is to wrap an import in a try/except structure, where failure to import a module results in a descriptive message being printed to the console, e.g., @@ -148,6 +148,18 @@ """ +MESSAGES[ + "wandb" +] = """ + +Weights and Biases is a tool for tracking and visualizing machine learning experiments. + +Weights and Biases can be installed via conda: + + conda install conda-forge::wandb + +""" + def import_(module: str): """Import a module or print a descriptive message and raise an ImportError @@ -225,10 +237,39 @@ def check_import(module: str): and an ImportError is raised. Examples -------- - >>> from modelforge.utils.package_import import check_import + >>> from modelforge.utils.io import check_import >>> check_import(module="ray") >>> from ray import tune """ imported_module = import_(module) del imported_module + + +from typing import Union, List + + +def parse_devices(value: str) -> Union[int, List[int]]: + """ + Parse the devices argument which can be either a single integer or a list of + integers. + + Parameters + ---------- + value : str + The input string representing either a single integer or a list of + integers. + + Returns + ------- + Union[int, List[int]] + Either a single integer or a list of integers. + """ + import ast + + # if multiple comma delimited values are passed, split them into a list + if value.startswith("[") and value.endswith("]"): + # Safely evaluate the string as a Python literal (list of ints) + return list(ast.literal_eval(value)) + else: + return int(value) diff --git a/modelforge/utils/misc.py b/modelforge/utils/misc.py index 1557a83e..cbfcf329 100644 --- a/modelforge/utils/misc.py +++ b/modelforge/utils/misc.py @@ -1,28 +1,9 @@ -from typing import Literal +""" +Module of miscellaneous utilities. +""" import torch from loguru import logger -from modelforge.dataset.dataset import DataModule - - -def visualize_model( - dm: DataModule, model_name: Literal["ANI2x", "PhysNet", "SchNet", "PaiNN", "SAKE"] -): - # visualize the compute graph - from modelforge.utils.io import import_ - - torchviz = import_("torchviz") - from modelforge.potential import NeuralNetworkPotentialFactory - - inference_model = NeuralNetworkPotentialFactory.generate_model( - "inference", model_name - ) - - nnp_input = next(iter(dm.train_dataloader())).nnp_input - yhat = inference_model(nnp_input) - torchviz.make_dot( - yhat, params=dict(list(inference_model.named_parameters())) - ).render(f"compute_graph_{inference_model.__class__.__name__}", format="png") class Welford: @@ -292,20 +273,129 @@ def __enter__(self): f"{self._file_path} in locked by another process; waiting until lock is released." ) - # try to lock the file; if the file is already locked, this will wait until the file is released - # I added helper function definitions that call fcntl, as we might not always want to use a context manager - # and to test the function in isolation + # try to lock the file; if the file is already locked, this will wait + # until the file is released. I added helper function definitions that + # call fcntl, as we might not always want to use a context manager and + # to test the function in isolation lock_file(self._file_handle) - # import fcntl - # fcntl.flock(self._file_handle.fileno(), fcntl.LOCK_EX) # return the opened file stream return self._file_handle def __exit__(self, *args): # unlock the file and close the file stream - # import fcntl - # fcntl.flock(self._file_handle.fileno(), fcntl.LOCK_UN) unlock_file(self._file_handle) self._file_handle.close() + + +import os +from functools import wraps + + +def lock_with_attribute(attribute_name): + """ + Decorator for locking a method using a lock file path stored in an instance + attribute. The attribute is accessed on the instance (`self`) at runtime. + + Parameters + ---------- + attribute_name : str + The name of the instance attribute that contains the lock file path. + + Examples + -------- + >>> from modelforge.utils.misc import lock_with_attribute + >>> + >>> class MyClass: + >>> def __init__(self, lock_file): + >>> self.method_lock = lock_file + >>> + >>> @lock_with_attribute('method_lock') + >>> def critical_section(self): + >>> print("Executing critical section") + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Retrieve the instance (`self`) + instance = args[0] + # Get the lock file path from the specified attribute + lock_file_path = getattr(instance, attribute_name) + + with OpenWithLock(lock_file_path, "w+") as lock_file_handle: + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def seed_random_number(seed: int): + """ + Seed the random number generator for reproducibility. + + Parameters + ---------- + seed : int + The seed for the random number generator. + """ + import random + + random.seed(seed) + + import numpy as np + + np.random.seed(seed) + import torch + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def load_configs_into_pydantic_models(potential_name: str, dataset_name: str): + from modelforge.tests.data import ( + potential_defaults, + training_defaults, + dataset_defaults, + runtime_defaults, + ) + from importlib import resources + import toml + + potential_path = ( + resources.files(potential_defaults) / f"{potential_name.lower()}.toml" + ) + dataset_path = resources.files(dataset_defaults) / f"{dataset_name.lower()}.toml" + training_path = resources.files(training_defaults) / "default.toml" + runtime_path = resources.files(runtime_defaults) / "runtime.toml" + + training_config_dict = toml.load(training_path) + dataset_config_dict = toml.load(dataset_path) + potential_config_dict = toml.load(potential_path) + runtime_config_dict = toml.load(runtime_path) + + potential_name = potential_config_dict["potential"]["potential_name"] + + from modelforge.potential import _Implemented_NNP_Parameters + + PotentialParameters = ( + _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name) + ) + potential_parameters = PotentialParameters(**potential_config_dict["potential"]) + + from modelforge.dataset.dataset import DatasetParameters + from modelforge.train.parameters import TrainingParameters, RuntimeParameters + + dataset_parameters = DatasetParameters(**dataset_config_dict["dataset"]) + training_parameters = TrainingParameters(**training_config_dict["training"]) + runtime_parameters = RuntimeParameters(**runtime_config_dict["runtime"]) + + return { + "potential": potential_parameters, + "dataset": dataset_parameters, + "training": training_parameters, + "runtime": runtime_parameters, + } diff --git a/modelforge/utils/prop.py b/modelforge/utils/prop.py index 81aaad68..5990f3bc 100644 --- a/modelforge/utils/prop.py +++ b/modelforge/utils/prop.py @@ -1,7 +1,12 @@ +""" +Module of dataclass definitions of properties. +""" + from dataclasses import dataclass import torch from typing import NamedTuple, Optional -from loguru import logger +from loguru import logger as log +from openff.units import unit @dataclass @@ -11,6 +16,17 @@ class PropertyNames: E: str F: Optional[str] = None total_charge: Optional[str] = None + dipole_moment: Optional[str] = None + + +PropertyUnits = { + "atomic_numbers": "dimensionless", + "positions": unit.nanometer, + "E": unit.kilojoule_per_mole, + "F": unit.kilojoule_per_mole / unit.nanometer, + "total_charge": unit.elementary_charge, + "dipole_moment": unit.elementary_charge * unit.nanometer, +} class SpeciesEnergies(NamedTuple): @@ -21,3 +37,193 @@ class SpeciesEnergies(NamedTuple): class SpeciesAEV(NamedTuple): species: torch.Tensor aevs: torch.Tensor + + +class NNPInput: + __slots__ = ( + "atomic_numbers", + "positions", + "atomic_subsystem_indices", + "per_system_total_charge", + "pair_list", + "per_atom_partial_charge", + "box_vectors", + "is_periodic", + ) + + def __init__( + self, + atomic_numbers: torch.Tensor, + positions: torch.Tensor, + atomic_subsystem_indices: torch.Tensor, + per_system_total_charge: torch.Tensor, + box_vectors: torch.Tensor = torch.zeros(3, 3), + is_periodic: torch.Tensor = torch.tensor([False]), + pair_list: torch.Tensor = torch.tensor([]), + per_atom_partial_charge: torch.Tensor = torch.tensor([]), + ): + self.atomic_numbers = atomic_numbers + self.positions = positions + self.atomic_subsystem_indices = atomic_subsystem_indices + self.per_system_total_charge = per_system_total_charge + self.pair_list = pair_list + self.per_atom_partial_charge = per_atom_partial_charge + self.box_vectors = box_vectors + self.is_periodic = is_periodic + + # Validate inputs + self._validate_inputs() + + def _validate_inputs(self): + # Get shapes of the arrays + atomic_numbers_shape = self.atomic_numbers.shape + positions_shape = self.positions.shape + atomic_subsystem_indices_shape = self.atomic_subsystem_indices.shape + + # Validate dimensions + if len(atomic_numbers_shape) != 1: + raise ValueError("atomic_numbers must be a 1D tensor or array") + if len(positions_shape) != 2 or positions_shape[1] != 3: + raise ValueError( + "positions must be a 2D tensor or array with shape [num_atoms, 3]" + ) + if self.box_vectors.shape[0] != 3 or self.box_vectors.shape[1] != 3: + print(f"{self.box_vectors.shape}") + raise ValueError("box_vectors must be a 3x3 tensor or array") + + if len(atomic_subsystem_indices_shape) != 1: + raise ValueError("atomic_subsystem_indices must be a 1D tensor or array") + + # Validate lengths + num_atoms = positions_shape[0] + if atomic_numbers_shape[0] != num_atoms: + raise ValueError( + "The size of atomic_numbers and the first dimension of positions must match" + ) + if atomic_subsystem_indices_shape[0] != num_atoms: + raise ValueError( + "The size of atomic_subsystem_indices and the first dimension of positions must match" + ) + + def to_device(self, device: torch.device): + """Move all tensors in this instance to the specified device.""" + + self.atomic_numbers = self.atomic_numbers.to(device) + self.positions = self.positions.to(device) + self.atomic_subsystem_indices = self.atomic_subsystem_indices.to(device) + self.per_system_total_charge = self.per_system_total_charge.to(device) + self.box_vectors = self.box_vectors.to(device) + self.is_periodic = self.is_periodic.to(device) + self.pair_list = self.pair_list.to(device) + self.per_atom_partial_charge = self.per_atom_partial_charge.to(device) + + return self + + def to_dtype(self, dtype: torch.dtype): + """Move all **relevant** tensors to dtype.""" + self.positions = self.positions.to(dtype) + self.box_vectors = self.box_vectors.to(dtype) + return self + + +class Metadata: + """ + A class to structure metadata for neural network potentials. + + Parameters + ---------- + per_system_energy : torch.Tensor + Energies for each system. + atomic_subsystem_counts : torch.Tensor + The number of atoms in each subsystem. + atomic_subsystem_indices_referencing_dataset : torch.Tensor + Indices referencing the dataset. + number_of_atoms : int + Total number of atoms. + per_atom_force : torch.Tensor, optional + Forces for each atom. + per_system_dipole_moment : torch.Tensor, optional + Dipole moments for each system. + + """ + + __slots__ = ( + "per_system_energy", + "atomic_subsystem_counts", + "atomic_subsystem_indices_referencing_dataset", + "number_of_atoms", + "per_atom_force", + "per_system_dipole_moment", + ) + + def __init__( + self, + per_system_energy: torch.Tensor, + atomic_subsystem_counts: torch.Tensor, + atomic_subsystem_indices_referencing_dataset: torch.Tensor, + number_of_atoms: int, + per_atom_force: torch.Tensor = None, + per_system_dipole_moment: torch.Tensor = None, + ): + self.per_system_energy = per_system_energy + self.atomic_subsystem_counts = atomic_subsystem_counts + self.atomic_subsystem_indices_referencing_dataset = ( + atomic_subsystem_indices_referencing_dataset + ) + self.number_of_atoms = number_of_atoms + self.per_atom_force = per_atom_force + self.per_system_dipole_moment = per_system_dipole_moment + + def to_device(self, device: torch.device): + """Move all tensors in this instance to the specified device.""" + self.per_system_energy = self.per_system_energy.to(device) + self.per_atom_force = self.per_atom_force.to(device) + self.atomic_subsystem_counts = self.atomic_subsystem_counts.to(device) + self.atomic_subsystem_indices_referencing_dataset = ( + self.atomic_subsystem_indices_referencing_dataset.to(device) + ) + self.per_system_dipole_moment = self.per_system_dipole_moment.to(device) + return self + + def to_dtype(self, dtype: torch.dtype): + self.per_system_energy = self.per_system_energy.to(dtype) + self.per_atom_force = self.per_atom_force.to(dtype) + self.per_system_dipole_moment = self.per_system_dipole_moment.to(dtype) + return self + + +@dataclass +class BatchData: + nnp_input: NNPInput + metadata: Metadata + + def to( + self, + device: torch.device, + ): # NOTE: this is required to move the data to device + """Move all data in this batch to the specified device and dtype.""" + self.nnp_input = self.nnp_input.to_device(device=device) + self.metadata = self.metadata.to_device(device=device) + return self + + def to_device( + self, + device: torch.device, + ): + """Move all data in this batch to the specified device and dtype.""" + self.nnp_input = self.nnp_input.to_device(device=device) + self.metadata = self.metadata.to_device(device=device) + return self + + def to_dtype( + self, + dtype: torch.dtype, + ): + """Move all data in this batch to the specified device and dtype.""" + self.nnp_input = self.nnp_input.to_dtype(dtype=dtype) + self.metadata = self.metadata.to_dtype(dtype=dtype) + return self + + def batch_size(self) -> int: + """Return the batch size.""" + return self.metadata.per_system_energy.size(dim=0) diff --git a/modelforge/utils/remote.py b/modelforge/utils/remote.py index ee2778af..a3d3f237 100644 --- a/modelforge/utils/remote.py +++ b/modelforge/utils/remote.py @@ -1,4 +1,4 @@ -"""Module for querying and fetching datafiles from remote sources""" +"""Module for querying remote sources and fetching datafiles""" from typing import Optional, List, Dict from loguru import logger @@ -111,7 +111,7 @@ def download_from_url( output_filename: str, length: Optional[int] = None, force_download=False, -) -> str: +): import requests import os @@ -173,243 +173,3 @@ def download_from_url( logger.debug( "Using previously downloaded file; set force_download=True to re-download." ) - - -# Figshare helper functions -def download_from_figshare( - url: str, md5_checksum: str, output_path: str, force_download=False -) -> str: - """ - Downloads a dataset from figshare for a given ndownloader url. - - Parameters - ---------- - url: str, required - Figshare ndownloader url (i.e., link to the data downloader) - md5_checksum: str, required - Expected md5 checksum of the downloaded file. - output_path: str, required - Location to download the file to. - force_download: str, default=False - If False: if the file exists in output_path, code will will use the local version. - If True, the file will be downloaded, even if it exists in output_path. - - Returns - ------- - str - Name of the file downloaded. - - Examples - -------- - >>> url = 'https://springernature.figshare.com/ndownloader/files/18112775' - >>> output_path = '/path/to/directory' - >>> downloaded_file_name = download_from_figshare(url, output_path) - - """ - - import requests - import os - from tqdm import tqdm - - # force to use ipv4; my ubuntu machine is timing out when it first tries ipv6 - # requests.packages.urllib3.util.connection.HAS_IPV6 = False - - chunk_size = 512 - # check to make sure the url we are given is hosted by figshare.com - if not is_url(url, "figshare.com"): - raise Exception(f"{url} is not a valid figshare.com url") - - # get the head of the request - head = requests.head(url) - - # Because the url on figshare calls a downloader, instead of the direct file, - # we need to figure out where the original file is stored to know how big it is. - # Here we will parse the header info to get the file the downloader links to - # and then get the head info from this link to fetch the length. - # This is not actually necessary, but useful for updating the download status bar. - # We also fetch the name of the file from the header of the download link - - temp_url = head.headers["location"].split("?")[0] - name = head.headers["X-Filename"].split("/")[-1] - - # make sure we can handle a path with a ~ in it - output_path = os.path.expanduser(output_path) - - # We need to check to make sure that the file that is stored in the output path - # has the correct checksum, e.g., to avoid a case where we have a partially downloaded file - # or to make sure we don't have two files with the same name, but different content. - if os.path.isfile(f"{output_path}/{name}"): - calculated_checksum = calculate_md5_checksum( - file_name=name, file_path=output_path - ) - if calculated_checksum != md5_checksum: - force_download = True - logger.debug( - "Checksum of existing file does not match expected checksum, re-downloading." - ) - - if not os.path.isfile(f"{output_path}/{name}") or force_download: - logger.debug(f"Downloading datafile from figshare to {output_path}/{name}.") - - temp_url_headers = requests.head(temp_url) - - os.makedirs(output_path, exist_ok=True) - try: - length = int(temp_url_headers.headers["Content-Length"]) - except: - print( - "Could not determine the length of the file to download. The download bar will not be accurate." - ) - length = -1 - r = requests.get(url, stream=True) - - from modelforge.utils.misc import OpenWithLock - - with OpenWithLock(f"{output_path}/{name}.lockfile", "w") as fl: - with open(f"{output_path}/{name}", "wb") as fd: - # if we couldn't fetch the length from figshare, which seems to happen for some records - # we just don't know how long the tqdm bar will be. - if length == -1: - for chunk in tqdm( - r.iter_content(chunk_size=chunk_size), - ascii=True, - desc="downloading", - ): - fd.write(chunk) - else: - for chunk in tqdm( - r.iter_content(chunk_size=chunk_size), - ascii=True, - desc="downloading", - total=(int(length / chunk_size) + 1), - ): - fd.write(chunk) - os.remove(f"{output_path}/{name}.lockfile") - - calculated_checksum = calculate_md5_checksum( - file_name=name, file_path=output_path - ) - if calculated_checksum != md5_checksum: - raise Exception( - f"Checksum of downloaded file {calculated_checksum} does not match expected checksum {md5_checksum}" - ) - else: # if the file exists and we don't set force_download to True, just use the cached version - logger.debug(f"Datafile {name} already exists in {output_path}.") - logger.debug( - "Using previously downloaded file; set force_download=True to re-download." - ) - - return name - - -def download_from_zenodo( - url: str, md5_checksum: str, output_path: str, force_download=False -) -> str: - """ - Downloads a dataset from zenodo for a given url. - - If the datafile exists in the output_path, by default it will not be redownloaded. - - Parameters - ---------- - url : str, required - Direct link to datafile to download. - md5_checksum: str, required - Expected md5 checksum of the downloaded file. - output_path: str, required - Location to download the file to. - force_download: str, default=False - If False: if the file exists in output_path, code will will use the local version. - If True, the file will be downloaded, even if it exists in output_path. - - Returns - ------- - str - Name of the file downloaded. - - Examples - -------- - >>> url = "https://zenodo.org/records/3401581/files/PTC-CMC/atools_ml-v0.1.zip" - >>> output_path = '/path/to/directory' - >>> md5_checksum = "d41d8cd98f00b204e9800998ecf8427e" - >>> downloaded_file_name = download_from_zenodo(url, md5_checksum, output_path) - - """ - - import requests - import os - from tqdm import tqdm - - # force to use ipv4; my ubuntu machine is timing out when it first tries ipv6 - # requests.packages.urllib3.util.connection.HAS_IPV6 = False - - chunk_size = 512 - # check to make sure the url we are given is hosted by figshare.com - - if not is_url(url, "zenodo.org"): - raise Exception(f"{url} is not a valid zenodo.org url") - - # get the head of the request - head = requests.head(url) - - # Because the url on figshare calls a downloader, instead of the direct file, - # we need to figure out where the original file is stored to know how big it is. - # Here we will parse the header info to get the file the downloader links to - # and then get the head info from this link to fetch the length. - # This is not actually necessary, but useful for updating the download status bar. - # We also fetch the name of the file from the header of the download link - name = head.headers["Content-Disposition"].split("filename=")[-1] - length = int(head.headers["Content-Length"]) - - # make sure we can handle a path with a ~ in it - output_path = os.path.expanduser(output_path) - - # We need to check to make sure that the file that is stored in the output path - # has the correct checksum, e.g., to avoid a case where we have a partially downloaded file - # or to make sure we don't have two files with the same name, but different content. - - if os.path.isfile(f"{output_path}/{name}"): - calculated_checksum = calculate_md5_checksum( - file_name=name, file_path=output_path - ) - if calculated_checksum != md5_checksum: - force_download = True - logger.debug( - "Checksum of existing file does not match expected checksum, re-downloading." - ) - - if not os.path.isfile(f"{output_path}/{name}") or force_download: - logger.debug(f"Downloading datafile from zenodo to {output_path}/{name}.") - - r = requests.get(url, stream=True) - - os.makedirs(output_path, exist_ok=True) - - from modelforge.utils.misc import OpenWithLock - - with OpenWithLock(f"{output_path}/{name}.lockfile", "w") as fl: - with open(f"{output_path}/{name}", "wb") as fd: - for chunk in tqdm( - r.iter_content(chunk_size=chunk_size), - ascii=True, - desc="downloading", - total=(int(length / chunk_size) + 1), - ): - fd.write(chunk) - os.remove(f"{output_path}/{name}.lockfile") - - calculated_checksum = calculate_md5_checksum( - file_name=name, file_path=output_path - ) - if calculated_checksum != md5_checksum: - raise Exception( - f"Checksum of downloaded file {calculated_checksum} does not match expected checksum {md5_checksum}." - ) - - else: # if the file exists and we don't set force_download to True, just use the cached version - logger.debug(f"Datafile {name} already exists in {output_path}.") - logger.debug( - "Using previously downloaded file; set force_download=True to re-download." - ) - - return name diff --git a/modelforge/utils/units.py b/modelforge/utils/units.py index f9b12806..c8b141c4 100644 --- a/modelforge/utils/units.py +++ b/modelforge/utils/units.py @@ -1,11 +1,22 @@ +""" +Module that handles unit system definitions and conversions. + +This module defines a custom unit context for converting between various +energy units and includes utility functions for handling units within +the model forge framework. +""" + from typing import Union from openff.units import unit -# define new context for converting energy (e.g., hartree) -# to energy/mol (e.g., kJ/mol) +# Define a chemical context for unit transformations +# This allows conversions between energy units like hartree and kJ/mol __all__ = ["chem_context"] chem_context = unit.Context("chem") + +# Add transformations to handle conversions between energy units per substance +# (mole) and other forms chem_context.add_transformation( "[force] * [length]", "[force] * [length]/[substance]", @@ -38,15 +49,66 @@ lambda unit, x: x / unit.avogadro_constant, ) +# Register the custom chemical context for use with the unit system +unit.add_context(chem_context) + + +def _convert_str_or_unit_to_unit_length(val: Union[unit.Quantity, str]) -> float: + """ + Convert a string or unit.Quantity representation of a length to nanometers. + + This function ensures that any input, whether a string or an OpenFF + unit.Quantity, is converted to a unit.Quantity in nanometers and returns the + magnitude. + + Parameters + ---------- + val : Union[unit.Quantity, str] + The value to convert to a unit length (nanometers). -def _convert(val: Union[unit.Quantity, str]) -> unit.Quantity: - """Convert a string representation of an OpenFF unit to a unit.Quantity""" + Returns + ------- + float + The value in nanometers. + + Examples + -------- + >>> _convert_str_or_unit_to_unit_length("1.0 * nanometer") + 1.0 + >>> _convert_str_or_unit_to_unit_length(unit.Quantity(1.0, unit.angstrom)) + 0.1 + """ if isinstance(val, str): - return unit.Quantity(val) - return val + val = unit.Quantity(val) + return val.to(unit.nanometer).m -unit.add_context(chem_context) +def _convert_str_to_unit(val: Union[unit.Quantity, str]) -> unit.Quantity: + """ + Convert a string representation of a unit to an OpenFF unit.Quantity. + + If the input is already a unit.Quantity, it is returned as is. + Parameters + ---------- + val : Union[unit.Quantity, str] + The value to convert to a unit.Quantity + + Returns + ------- + unit.Quantity + The value and unit as a unit.Quantity + + Examples + -------- + >>> _convert_str_to_unit("1.0 * kilocalorie / mole") + Quantity(value=1.0, unit=kilocalorie/mole) + >>> _convert_str_to_unit(unit.Quantity(1.0, unit.kilocalorie / unit.mole)) + Quantity(value=1.0, unit=kilocalorie/mole) + + """ + if isinstance(val, str): + return unit.Quantity(val) + return val def print_modelforge_unit_system(): diff --git a/modelforge/utils/vis.py b/modelforge/utils/vis.py new file mode 100644 index 00000000..79c503a9 --- /dev/null +++ b/modelforge/utils/vis.py @@ -0,0 +1,27 @@ +from modelforge.custom_types import ModelType +from modelforge.dataset.dataset import NNPInput + + +def visualize_model( + nnp_input: NNPInput, + potential_name: ModelType, + output_dir: str, +): + # visualize the compute graph + from modelforge.utils.io import import_ + from modelforge.tests.helper_functions import setup_potential_for_test + + torchviz = import_("torchviz") + + potential = setup_potential_for_test( + potential_name, + "inference", + ) + + yhat = potential(nnp_input)["per_system_energy"] + torchviz.make_dot( + yhat, + params=dict(list(potential.named_parameters())), + show_attrs=True, + show_saved=True, + ).render(f"{output_dir}/compute_graph_{potential_name}1", format="png") diff --git a/notebooks/representation.ipynb b/notebooks/representation.ipynb new file mode 100644 index 00000000..2c02f3cd --- /dev/null +++ b/notebooks/representation.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generate rbf output for each of the different RBF implementations and visualize the output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from modelforge.potential.representation import PhysNetRadialBasisFunction, AniRadialBasisFunction, SchnetRadialBasisFunction\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Test parameters\n", + "distances = torch.tensor([[0.5], [1.0], [1.5]], dtype=torch.float32) / 10\n", + "number_of_radial_basis_functions = 100\n", + "max_distance = 2.0 / 10\n", + "min_distance = 0.0\n", + "dtype = torch.float32\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define colors for each radial basis function\n", + "colors = ['blue', 'green', 'orange']\n", + "\n", + "for idx, rbf_fn in enumerate([PhysNetRadialBasisFunction, AniRadialBasisFunction, SchnetRadialBasisFunction]):\n", + " print(f\"Testing {rbf_fn.__name__}\")\n", + "\n", + "\n", + " # Instantiate the RBF\n", + " rbf = rbf_fn(\n", + " number_of_radial_basis_functions=number_of_radial_basis_functions,\n", + " max_distance=max_distance,\n", + " min_distance=min_distance,\n", + " dtype=dtype,\n", + " trainable_centers_and_scale_factors=False,\n", + " )\n", + "\n", + " # Get actual outputs\n", + " actual_output = rbf(distances)\n", + "\n", + " import numpy as np\n", + " import matplotlib.pyplot as plt\n", + " rs = torch.tensor([[r] for r in np.linspace(0,0.2, number_of_radial_basis_functions)])\n", + " for i in range(3):\n", + " plt.plot(rs, actual_output[i].numpy(), color=colors[idx])\n", + " # Draw the vertical line (axvline)\n", + " plt.axvline(distances[i].numpy(), 0, 0.2, c='r')\n", + " # Add the legend entry for the radial basis function once\n", + " plt.plot([], [], color=colors[idx], label=f'{rbf_fn.__name__}')\n", + "\n", + "plt.legend()\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "modelforge", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index ed8b70a9..c9652872 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,7 @@ [build-system] -requires = ["setuptools>=61.0", "versioningit~=2.0"] +requires = ["setuptools>=61.0", "versioningit~=3.0"] build-backend = "setuptools.build_meta" -# Self-descriptive entries which should always be present -# https://packaging.python.org/en/latest/specifications/declaring-project-metadata/ [project] name = "modelforge" description = "Infrastructure to implement and train NNPs" @@ -17,45 +15,27 @@ license = { text = "MIT" } classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Operating System :: POSIX :: Linux", + "Environment :: GPU", + "Environment :: GPU :: NVIDIA CUDA", ] -requires-python = ">=3.8" -# Declare any run-time dependencies that should be installed with the package. -#dependencies = [ -# "importlib-resources;python_version<'3.10'", -#] +requires-python = ">=3.10" -# Update the urls once the hosting is set up. -#[project.urls] -#"Source" = "https://github.com//modelforge/" -#"Documentation" = "https://modelforge.readthedocs.io/" - -[project.optional-dependencies] -test = [ - "pytest>=6.1.2", - "pytest-runner" -] +[project.urls] +Source = "https://github.com/choderalab/modelforge" +Documentation = "https://modelforge.readthedocs.io/" +Wiki = "https://github.com/choderalab/modelforge/wiki" [tool.setuptools] -# This subkey is a beta stage development and keys may change in the future, see https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html for more details -# -# As of version 0.971, mypy does not support type checking of installed zipped -# packages (because it does not actually import the Python packages). -# We declare the package not-zip-safe so that our type hints are also available -# when checking client code that uses our (installed) package. -# Ref: +# Disable zipping because mypy cannot read zip imports and this may affect downstream development. # https://mypy.readthedocs.io/en/stable/installed_packages.html?highlight=zip#using-installed-packages-with-mypy-pep-561 +# NOTE: We might consider removing this once we can test the code in a +# production environment since zipping the package may increase performance. zip-safe = false -# Let setuptools discover the package in the current directory, -# but be explicit about non-Python files. -# See also: -# https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html#setuptools-specific-configuration -# Note that behavior is currently evolving with respect to how to interpret the -# "data" and "tests" subdirectories. As of setuptools 63, both are automatically -# included if namespaces is true (default), even if the package is named explicitly -# (instead of using 'find'). With 'find', the 'tests' subpackage is discovered -# recursively because of its __init__.py file, but the data subdirectory is excluded -# with include-package-data = false and namespaces = false. -include-package-data = false +include-package-data = true + [tool.setuptools.packages.find] namespaces = false where = ["."] @@ -66,6 +46,7 @@ modelforge = [ "py.typed" ] +# https://versioningit.readthedocs.io/en/stable/configuration.html# [tool.versioningit] default-version = "1+unknown" @@ -75,9 +56,7 @@ dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" distance-dirty = "{base_version}+{distance}.{vcs}{rev}.dirty" [tool.versioningit.vcs] -# The method key: -method = "git" # <- The method name -# Parameters to pass to the method: +method = "git" match = ["*"] default-tag = "1.0.0" @@ -90,7 +69,7 @@ file = "modelforge/_version.py" omit = [ # Omit the tests "*/tests/*", - # Omit generated versioneer + # Omit generated versioningit "modelforge/_version.py" ] diff --git a/scripts/config.toml b/scripts/config.toml index 1c4379e3..27d118dc 100644 --- a/scripts/config.toml +++ b/scripts/config.toml @@ -1,39 +1,51 @@ [potential] -model_name = "SchNet" +potential_name = "ANI2x" [potential.core_parameter] -max_Z = 101 -number_of_atom_features = 32 -number_of_radial_basis_functions = 20 -cutoff = "5.0 angstrom" -number_of_interaction_modules = 3 -number_of_filters = 32 -shared_interactions = false +angle_sections = 4 +maximum_interaction_radius = "5.1 angstrom" +minimum_interaction_radius = "0.8 angstrom" +number_of_radial_basis_functions = 16 +maximum_interaction_radius_for_angular_features = "3.5 angstrom" +minimum_interaction_radius_for_angular_features = "0.8 angstrom" +angular_dist_divisions = 8 +predicted_properties = ["per_atom_energy"] +predicted_dim = [1] + +[potential.core_parameter.activation_function_parameter] +activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used. + +[potential.core_parameter.activation_function_parameter.activation_function_arguments] +alpha = 0.1 [potential.postprocessing_parameter] +properties_to_process = ['per_atom_energy'] [potential.postprocessing_parameter.per_atom_energy] -normalize = true -from_atom_to_molecule_reduction = true +normalize = false +from_atom_to_system_reduction = true keep_per_atom_property = true -[potential.postprocessing_parameter.general_postprocessing_operation] -calculate_molecular_self_energy = true [dataset] -dataset_name = "QM9" +dataset_name = "PHALKETHOH" +version_select = "nc_1000_v0" +num_workers = 4 +pin_memory = true [training] -save_dir = "test" -experiment_name = "your_experiment_name" -accelerator = "cpu" -num_nodes = 1 -devices = 1 # [0,1,2,3] +number_of_epochs = 20 remove_self_energies = true -batch_size = 128 +batch_size = 16 +lr = 0.5e-3 +monitor = "val/per_system_energy/rmse" +shift_center_of_mass_to_origin = false -[training.training_parameter] -lr = 1e-3 +[training.experiment_logger] +logger_name = "tensorboard" -[training.training_parameter.lr_scheduler_config] +[training.experiment_logger.tensorboard_configuration] +save_dir = "logs" +[training.lr_scheduler] +scheduler_name = "ReduceLROnPlateau" frequency = 1 mode = "min" factor = 0.1 @@ -42,34 +54,42 @@ cooldown = 5 min_lr = 1e-8 threshold = 0.1 threshold_mode = "abs" -monitor = "val/per_molecule_energy/rmse" +monitor = "val/per_system_energy/rmse" interval = "epoch" -[training.training_parameter.loss_parameter] -loss_property = ['per_molecule_energy'] # use . -[training.training_parameter.loss_parameter.weight] -per_molecule_energy = 0.999 #NOTE: reciproce units -per_atom_force = 0.001 +[training.loss_parameter] +loss_components = ['per_system_energy', 'per_atom_force'] # use + +[training.loss_parameter.weight] +per_system_energy = 1 +per_atom_force = 0.8 + +[training.loss_parameter.target_weight] +per_atom_force = 0.2 + +[training.loss_parameter.mixing_steps] +per_atom_force = -0.1 + [training.early_stopping] verbose = true -monitor = "val/per_molecule_energy/rmse" +monitor = "val/per_system_energy/rmse" min_delta = 0.001 patience = 50 -[training.experiment_logger] -logger_name = "wandb" -save_dir = "test" -experiment_name = "{model_name}_{dataset_name}" - - [training.splitting_strategy] name = "random_record_splitting_strategy" data_split = [0.8, 0.1, 0.1] +seed = 42 [runtime] -accelerator = "cpu" -num_nodes = 1 -devices = 1 #[0,1,2,3] +save_dir = "test_setup" +experiment_name = "{potential_name}_{dataset_name}" local_cache_dir = "./cache" -nr_of_epochs = 10 +accelerator = "cpu" +number_of_nodes = 1 +devices = 1 #[0,1,2,3] +checkpoint_path = "None" +simulation_environment = "PyTorch" +log_every_n_steps = 1 +verbose = true diff --git a/scripts/perform_training.py b/scripts/perform_training.py index db0f424c..ffbcc649 100644 --- a/scripts/perform_training.py +++ b/scripts/perform_training.py @@ -1,10 +1,61 @@ -# This is an example script that trains an implemented model on the QM9 dataset. -# tensorboard --logdir tb_logs +# This script provides a command line interface to train an neurtal network potential. +import argparse +from modelforge.train.training import read_config_and_train +from modelforge.utils.io import parse_devices -if __name__ == "__main__": +parse = argparse.ArgumentParser(description="Perform Training Using Modelforge") - import fire - from modelforge.train.training import read_config_and_train +parse.add_argument( + "--condensed_config_path", type=str, help="Path to the condensed TOML config file" +) +parse.add_argument( + "--training_parameter_path", type=str, help="Path to the training TOML config file" +) +parse.add_argument( + "--dataset_parameter_path", type=str, help="Path to the dataset TOML config file" +) +parse.add_argument( + "--potential_parameter_path", + type=str, + help="Path to the potential TOML config file", +) +parse.add_argument( + "--runtime_parameter_path", type=str, help="Path to the runtime TOML config file" +) +parse.add_argument("--accelerator", type=str, help="Accelerator to use for training") +parse.add_argument( + "--devices", type=parse_devices, help="Device(s) to use for training" +) +parse.add_argument( + "--number_of_nodes", type=int, help="Number of nodes to use for training" +) +parse.add_argument("--experiment_name", type=str, help="Name of the experiment") +parse.add_argument("--save_dir", type=str, help="Directory to save the model") +parse.add_argument("--local_cache_dir", type=str, help="Local cache directory") +parse.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint") +parse.add_argument("--log_every_n_steps", type=int, help="Log every n steps") +parse.add_argument( + "--simulation_environment", + type=str, + help="Simulation environment to use for training", +) - fire.Fire(read_config_and_train) + +args = parse.parse_args() +read_config_and_train( + condensed_config_path=args.condensed_config_path, + training_parameter_path=args.training_parameter_path, + dataset_parameter_path=args.dataset_parameter_path, + potential_parameter_path=args.potential_parameter_path, + runtime_parameter_path=args.runtime_parameter_path, + accelerator=args.accelerator, + devices=args.devices, + number_of_nodes=args.number_of_nodes, + experiment_name=args.experiment_name, + save_dir=args.save_dir, + local_cache_dir=args.local_cache_dir, + checkpoint_path=args.checkpoint_path, + log_every_n_steps=args.log_every_n_steps, + simulation_environment=args.simulation_environment, +) diff --git a/scripts/profile_epoch.py b/scripts/profile_epoch.py index d4c52657..559de787 100644 --- a/scripts/profile_epoch.py +++ b/scripts/profile_epoch.py @@ -4,7 +4,6 @@ def perform_training(trainer, model, dm): - # Run training loop and validate trainer.fit( model, @@ -13,7 +12,7 @@ def perform_training(trainer, model, dm): ) -def setup(model_name: str): +def setup(potential_name: str): from modelforge.dataset.utils import RandomRecordSplittingStrategy from lightning import Trainer from modelforge.potential import NeuralNetworkPotentialFactory @@ -22,17 +21,17 @@ def setup(model_name: str): from modelforge import tests as modelforge_tests config = return_toml_config( - f"{resources.files(modelforge_tests)}/data/training_defaults/{model_name.lower()}_qm9.toml" + f"{resources.files(modelforge_tests)}/data/training_defaults/{potential_name.lower()}_qm9.toml" ) # Extract parameters potential_config = config["potential"] - training_config = config["training"] + training_config = config["runtime_defaults"] dataset_config = config["dataset"] training_config["nr_of_epochs"] = 1 dataset_config["version_select"] = "nc_1000_v0" - model_name = potential_config["model_name"] + potential_name = potential_config["potential_name"] dataset_name = dataset_config["dataset_name"] version_select = dataset_config.get("version_select", "latest") accelerator = training_config.get("accelerator", "cpu") @@ -64,20 +63,19 @@ def setup(model_name: str): # Set up model model = NeuralNetworkPotentialFactory.generate_model( - use="training", - model_type=model_name, - model_parameters=potential_config["potential_parameters"], - training_parameters=training_config["training_parameters"], + use="runtime_defaults", + potential_name=potential_name, + model_parameters=potential_config["potential"], + training_parameters=training_config["training"], ) return trainer, model, dm if __name__ == "__main__": - - model_name = "SchNet" + potential_name = "SchNet" trainer, model, dm = setup( - model_name=model_name, + potential_name=potential_name, ) perform_training(trainer, model, dm) diff --git a/scripts/profile_network.py b/scripts/profile_network.py index 8ea38ffa..29d783e8 100644 --- a/scripts/profile_network.py +++ b/scripts/profile_network.py @@ -12,19 +12,19 @@ def profile_network(model, data): )[0] -def setup(model_name: str): +def setup(potential_name: str): from importlib import resources from modelforge import tests as modelforge_tests config = return_toml_config( - f"{resources.files(modelforge_tests)}data/training_defaults/{model_name.lower()}_qm9.toml" + f"{resources.files(modelforge_tests)}data/training_defaults/{potential_name.lower()}_qm9.toml" ) # Extract parameters potential_parameter = config["potential"].get("potential_parameter", {}) model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, + model_type=potential_name, simulation_environment="PyTorch", model_parameters=potential_parameter, ) @@ -54,9 +54,9 @@ def setup(model_name: str): from torch.profiler import profile, record_function, ProfilerActivity import torch - model_name = "SchNet" + potential_name = "SchNet" - model, data = setup(model_name) + model, data = setup(potential_name) with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], # record_shapes=True, diff --git a/scripts/training_run.sh b/scripts/training_run.sh index e20cc75a..ae89f90b 100644 --- a/scripts/training_run.sh +++ b/scripts/training_run.sh @@ -2,4 +2,4 @@ # training run with a small number of epochs with default # parameters -python perform_training.py config.toml \ No newline at end of file +python perform_training.py --condensed_config_path config.toml diff --git a/scripts/tuning.py b/scripts/tuning.py index a323c491..2b681799 100644 --- a/scripts/tuning.py +++ b/scripts/tuning.py @@ -1,11 +1,8 @@ import torch -from modelforge.utils.io import import_ -tune = import_("ray").tune -air = import_("ray").air -# from ray import tune, air -ASHAScheduler = import_("ray").tune.schedulers.ASHAScheduler -# from ray.tune.schedulers import ASHAScheduler + +from ray import tune, air +from ray.tune.schedulers import ASHAScheduler from modelforge.potential import NeuralNetworkPotentialFactory from modelforge.dataset.qm9 import QM9Dataset diff --git a/setup.py b/setup.py deleted file mode 100644 index c7e538b2..00000000 --- a/setup.py +++ /dev/null @@ -1,21 +0,0 @@ -from setuptools import setup - -setup( - name="modelforge", - version="0.1", - packages=["modelforge"], - package_data={ - "modelforge": [ - "dataset/yaml_files/*", - "curation/yaml_files/*", - "tests/data/potential_defaults/*", - "tests/data/training_defaults/*", - ] - }, - url="https://github.com/choderalab/modelforge", - license="MIT", - author="Chodera lab, Marcus Wieder, Christopher Iacovella, and others", - author_email="", - description="A library for building and training neural network potentials", - include_package_data=True, -)