From faf9c9b1b3d7cd55ee2252d45493ae62846c30ad Mon Sep 17 00:00:00 2001
From: Arnav Nagle
Date: Wed, 29 May 2024 11:06:16 -0700
Subject: [PATCH 01/78] Copy SchNet implementation and rename to SpookyNet

---
 modelforge/potential/__init__.py   |   3 +
 modelforge/potential/spookynet.py  | 429 +++++++++++++++++++++++++++++
 modelforge/tests/test_spookynet.py |  50 ++++
 3 files changed, 482 insertions(+)
 create mode 100644 modelforge/potential/spookynet.py
 create mode 100644 modelforge/tests/test_spookynet.py

diff --git a/modelforge/potential/__init__.py b/modelforge/potential/__init__.py
index c3f3a6f0..ae27f060 100644
--- a/modelforge/potential/__init__.py
+++ b/modelforge/potential/__init__.py
@@ -3,6 +3,8 @@
 from .painn import PaiNN
 from .ani import ANI2x
 from .sake import SAKE
+from .spookynet import SpookyNet
+
 from .utils import (
     CosineCutoff,
     RadialSymmetryFunction,
@@ -20,6 +22,7 @@ class _Implemented_NNPs(Enum):
     PAINN = PaiNN
     PHYSNET = PhysNet
     SAKE = SAKE
+    SPOOKYNET = SpookyNet
 
     @classmethod
     def get_neural_network_class(cls, neural_network_name: str):
diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py
new file mode 100644
index 00000000..b426fe34
--- /dev/null
+++ b/modelforge/potential/spookynet.py
@@ -0,0 +1,429 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Dict, Optional
+
+import torch
+import torch.nn as nn
+from loguru import logger as log
+from openff.units import unit
+
+from .models import CoreNetwork
+
+if TYPE_CHECKING:
+    from .models import PairListOutputs
+    from modelforge.potential.utils import NNPInput
+
+from modelforge.potential.utils import NeuralNetworkData
+
+
+@dataclass
+class SpookyNetNeuralNetworkData(NeuralNetworkData):
+    """
+    A dataclass to structure the inputs specifically for SpookyNet-based neural network potentials,
+    including the necessary geometric and chemical information, along with the radial symmetry
+    function expansion (`f_ij`) and the cosine cutoff (`f_cutoff`) to accurately represent
+    atomistic systems for energy predictions.
+
+    Attributes
+    ----------
+    pair_indices : torch.Tensor
+        A 2D tensor of shape [2, num_pairs], indicating the indices of atom pairs within a molecule or system.
+    d_ij : torch.Tensor
+        A tensor containing the distances between each pair of atoms identified in `pair_indices`.
+        Shape: [num_pairs, 1].
+    r_ij : torch.Tensor
+        A 2D tensor of shape [num_pairs, 3], representing the displacement vectors between each pair of atoms.
+    number_of_atoms : int
+        An integer indicating the number of atoms in the batch.
+    positions : torch.Tensor
+        A 2D tensor of shape [num_atoms, 3], representing the XYZ coordinates of each atom within the system.
+    atomic_numbers : torch.Tensor
+        A 1D tensor containing atomic numbers for each atom, used to identify the type of each atom in the system(s).
+    atomic_subsystem_indices : torch.Tensor
+        A 1D tensor mapping each atom to its respective subsystem or molecule, useful for systems involving
+        multiple molecules or distinct subsystems.
+    total_charge : torch.Tensor
+        A tensor with the total charge of each system or molecule. Shape: [num_systems], where each entry
+        corresponds to a distinct system or molecule.
+    atomic_embedding : torch.Tensor
+        A 2D tensor containing embeddings or features for each atom, derived from atomic numbers.
+        Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors.
+    f_ij : Optional[torch.Tensor]
+        A tensor representing the radial symmetry function expansion of distances between atom pairs,
+        capturing the local chemical environment. Shape: [num_pairs, num_features], where `num_features`
+        is the dimensionality of the radial symmetry function expansion. This field will be populated
+        after initialization.
+    f_cutoff : Optional[torch.Tensor]
+        A tensor representing the cosine cutoff function applied to the radial symmetry function expansion,
+        ensuring that atom pair contributions diminish smoothly to zero at the cutoff radius.
+        Shape: [num_pairs]. This field will be populated after initialization.
+
+    Notes
+    -----
+    The `SpookyNetNeuralNetworkData` class is designed to encapsulate all necessary inputs for
+    SpookyNet-based neural network potentials in a structured and type-safe manner, facilitating
+    efficient and accurate processing of input data by the model. The inclusion of radial symmetry
+    functions (`f_ij`) and cosine cutoff functions (`f_cutoff`) allows for a detailed and nuanced
+    representation of the atomistic systems, crucial for the accurate prediction of system energies
+    and properties.
+
+    Examples
+    --------
+    >>> inputs = SpookyNetNeuralNetworkData(
+    ...     pair_indices=torch.tensor([[0, 1], [0, 2], [1, 2]]),
+    ...     d_ij=torch.tensor([1.0, 1.0, 1.0]),
+    ...     r_ij=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
+    ...     number_of_atoms=3,
+    ...     positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]),
+    ...     atomic_numbers=torch.tensor([1, 6, 8]),
+    ...     atomic_subsystem_indices=torch.tensor([0, 0, 0]),
+    ...     total_charge=torch.tensor([0.0]),
+    ...     atomic_embedding=torch.randn(3, 5),  # Example atomic embeddings
+    ...     f_ij=torch.randn(3, 4),  # Example radial symmetry function expansion
+    ...     f_cutoff=torch.tensor([0.5, 0.5, 0.5])  # Example cosine cutoff function
+    ... )
+    """
+
+    atomic_embedding: torch.Tensor
+    f_ij: Optional[torch.Tensor] = field(default=None)
+    f_cutoff: Optional[torch.Tensor] = field(default=None)
+
+
+class SpookyNetCore(CoreNetwork):
+    def __init__(
+        self,
+        max_Z: int = 100,
+        number_of_atom_features: int = 64,
+        number_of_radial_basis_functions: int = 20,
+        number_of_interaction_modules: int = 3,
+        number_of_filters: int = 64,
+        shared_interactions: bool = False,
+        cutoff: unit.Quantity = 5.0 * unit.angstrom,
+    ) -> None:
+        """
+        Initialize the SpookyNet class.
+
+        Parameters
+        ----------
+        max_Z : int, default=100
+            Maximum atomic number to be embedded.
+        number_of_atom_features : int, default=64
+            Dimension of the embedding vectors for atomic numbers.
+        number_of_radial_basis_functions : int, default=20
+            Number of radial basis functions used to expand pairwise distances.
+        number_of_interaction_modules : int, default=3
+            Number of interaction modules applied in sequence.
+        number_of_filters : int, default=64
+            Dimension of the intermediate filter space.
+        shared_interactions : bool, default=False
+            Whether the interaction modules share parameters.
+        cutoff : openff.units.unit.Quantity, default=5*unit.angstrom
+            The cutoff distance for interactions.
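+
+        Examples
+        --------
+        A minimal construction sketch; the values below are illustrative
+        defaults rather than tuned hyperparameters:
+
+        >>> from openff.units import unit
+        >>> core = SpookyNetCore(
+        ...     max_Z=100,
+        ...     number_of_atom_features=64,
+        ...     number_of_radial_basis_functions=20,
+        ...     number_of_interaction_modules=3,
+        ...     number_of_filters=64,
+        ...     cutoff=5.0 * unit.angstrom,
+        ... )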
+ """ + from .utils import Dense, ShiftedSoftplus + + log.debug("Initializing SpookyNet model.") + super().__init__(cutoff) + self.number_of_atom_features = number_of_atom_features + self.number_of_filters = number_of_filters or self.number_of_atom_features + self.number_of_radial_basis_functions = number_of_radial_basis_functions + + # embedding + from modelforge.potential.utils import Embedding + + self.embedding_module = Embedding(max_Z, number_of_atom_features) + + # initialize the energy readout + from .processing import FromAtomToMoleculeReduction + + self.readout_module = FromAtomToMoleculeReduction() + + # Initialize representation block + self.spookynet_representation_module = SpookyNetRepresentation( + cutoff, number_of_radial_basis_functions + ) + # Intialize interaction blocks + self.interaction_modules = nn.ModuleList( + [ + SpookyNetInteractionModule( + self.number_of_atom_features, + self.number_of_filters, + number_of_radial_basis_functions, + ) + for _ in range(number_of_interaction_modules) + ] + ) + + # final output layer + self.energy_layer = nn.Sequential( + Dense( + number_of_atom_features, + number_of_atom_features, + activation=ShiftedSoftplus(), + ), + Dense( + number_of_atom_features, + 1, + ), + ) + + def _model_specific_input_preparation( + self, data: "NNPInput", pairlist_output: "PairListOutputs" + ) -> SpookyNetNeuralNetworkData: + number_of_atoms = data.atomic_numbers.shape[0] + + nnp_input = SpookyNetNeuralNetworkData( + pair_indices=pairlist_output.pair_indices, + d_ij=pairlist_output.d_ij, + r_ij=pairlist_output.r_ij, + number_of_atoms=number_of_atoms, + positions=data.positions, + atomic_numbers=data.atomic_numbers, + atomic_subsystem_indices=data.atomic_subsystem_indices, + total_charge=data.total_charge, + atomic_embedding=self.embedding_module( + data.atomic_numbers + ), # atom embedding + ) + + return nnp_input + + def _forward(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: + """ + Calculate the energy for a given input batch. + + Parameters + ---------- + data : NamedTuple + + Returns + ------- + Dict[str, torch.Tensor] + Calculated energies; shape (nr_systems,). + """ + + # Compute the representation for each atom (transform to radial basis set, multiply by cutoff) + representation = self.spookynet_representation_module(data.d_ij) + data.f_ij = representation["f_ij"] + data.f_cutoff = representation["f_cutoff"] + x = data.atomic_embedding + # Iterate over interaction blocks to update features + for interaction in self.interaction_modules: + v = interaction( + x, + data.pair_indices, + representation["f_ij"], + representation["f_cutoff"], + ) + x = x + v # Update atomic features + + E_i = self.energy_layer(x).squeeze(1) + + return { + "E_i": E_i, + "q": x, + "atomic_subsystem_indices": data.atomic_subsystem_indices, + } + + +from torch_scatter import scatter_add + + +class SpookyNetInteractionModule(nn.Module): + def __init__( + self, + number_of_atom_features: int, + number_of_filters: int, + number_of_radial_basis_functions: int, + ) -> None: + """ + Initialize the SpookyNet interaction block. + + Parameters + ---------- + number_of_atom_features : int + Number of atom ffeatures, defines the dimensionality of the embedding. + number_of_filters : int + Number of filters, defines the dimensionality of the intermediate features. + number_of_radial_basis_functions : int + Number of radial basis functions. 
+ """ + super().__init__() + from .utils import Dense, ShiftedSoftplus + + assert ( + number_of_radial_basis_functions > 4 + ), "Number of radial basis functions must be larger than 10." + assert number_of_filters > 1, "Number of filters must be larger than 1." + assert ( + number_of_atom_features > 10 + ), "Number of atom basis must be larger than 10." + + self.number_of_atom_features = number_of_atom_features # Initialize parameters + self.intput_to_feature = Dense( + number_of_atom_features, number_of_filters, bias=False, activation=None + ) + self.feature_to_output = nn.Sequential( + Dense( + number_of_filters, number_of_atom_features, activation=ShiftedSoftplus() + ), + Dense(number_of_atom_features, number_of_atom_features, activation=None), + ) + self.filter_network = nn.Sequential( + Dense( + number_of_radial_basis_functions, + number_of_filters, + activation=ShiftedSoftplus(), + ), + Dense(number_of_filters, number_of_filters, activation=None), + ) + + def forward( + self, + x: torch.Tensor, + pairlist: torch.Tensor, # shape [n_pairs, 2] + f_ij: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] + f_ij_cutoff: torch.Tensor, # shape [n_pairs, 1] + ) -> torch.Tensor: + """ + Forward pass for the interaction block. + + Parameters + ---------- + x : torch.Tensor, shape [nr_of_atoms_in_systems, nr_atom_basis] + Input feature tensor for atoms. + pairlist : torch.Tensor, shape [n_pairs, 2] + f_ij : torch.Tensor, shape [n_pairs, 1, number_of_radial_basis_functions] + Radial basis functions for pairs of atoms. + f_ij_cutoff : torch.Tensor, shape [n_pairs, 1] + + Returns + ------- + torch.Tensor, shape [nr_of_atoms_in_systems, nr_atom_basis] + Updated feature tensor after interaction block. + """ + idx_i, idx_j = pairlist[0], pairlist[1] + + # Map input features to the filter space + x = self.intput_to_feature(x) + + # Generate interaction filters based on radial basis functions + W_ij = self.filter_network(f_ij.squeeze(1)) + W_ij = W_ij * f_ij_cutoff + + # Perform continuous-filter convolution + x_j = x[idx_j] + x_ij = x_j * W_ij + x = scatter_add(x_ij, idx_i, dim=0, dim_size=x.size(0)) + + return self.feature_to_output(x) + + +class SpookyNetRepresentation(nn.Module): + def __init__( + self, + radial_cutoff: unit.Quantity, + number_of_radial_basis_functions: int, + ): + """ + Initialize the SpookyNet representation layer. + + Parameters + ---------- + Radial Basis Function Module + """ + super().__init__() + + self.radial_symmetry_function_module = self._setup_radial_symmetry_functions( + radial_cutoff, number_of_radial_basis_functions + ) + # cutoff + from modelforge.potential import CosineCutoff + + self.cutoff_module = CosineCutoff(radial_cutoff) + + def _setup_radial_symmetry_functions( + self, radial_cutoff: unit.Quantity, number_of_radial_basis_functions: int + ): + from .utils import SchnetRadialSymmetryFunction + + radial_symmetry_function = SchnetRadialSymmetryFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=radial_cutoff, + dtype=torch.float32, + ) + return radial_symmetry_function + + def forward(self, d_ij: torch.Tensor) -> Dict[str, torch.Tensor]: + """ + Generate the radial symmetry representation of the pairwise distances. 
+
+        Parameters
+        ----------
+        d_ij : torch.Tensor
+            Pairwise distances between atoms; shape [n_pairs, 1].
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            `f_ij`: radial basis functions for pairs of atoms; shape [n_pairs, 1, number_of_radial_basis_functions].
+            `f_cutoff`: cosine cutoff values; shape [n_pairs, 1].
+        """
+
+        # Convert distances to radial basis functions
+        f_ij = self.radial_symmetry_function_module(
+            d_ij
+        )  # shape (n_pairs, 1, number_of_radial_basis_functions)
+
+        f_cutoff = self.cutoff_module(d_ij)  # shape (n_pairs, 1)
+
+        return {"f_ij": f_ij, "f_cutoff": f_cutoff}
+
+
+from .models import InputPreparation, NNPInput, BaseNetwork
+
+
+class SpookyNet(BaseNetwork):
+    def __init__(
+        self,
+        max_Z: int = 101,
+        number_of_atom_features: int = 32,
+        number_of_radial_basis_functions: int = 20,
+        number_of_interaction_modules: int = 3,
+        cutoff: unit.Quantity = 5 * unit.angstrom,
+        number_of_filters: int = 32,
+        shared_interactions: bool = False,
+    ) -> None:
+        """
+        Initialize the SpookyNet network.
+
+        Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with electronic
+        degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021).
+
+        Parameters
+        ----------
+        max_Z : int, default=101
+            Maximum atomic number to be embedded.
+        number_of_atom_features : int, default=32
+            Dimension of the embedding vectors for atomic numbers.
+        number_of_radial_basis_functions : int, default=20
+            Number of radial basis functions used to expand pairwise distances.
+        number_of_interaction_modules : int, default=3
+            Number of interaction modules applied in sequence.
+        cutoff : openff.units.unit.Quantity, default=5*unit.angstrom
+            The cutoff distance for interactions.
+        number_of_filters : int, default=32
+            Dimension of the intermediate filter space.
+        shared_interactions : bool, default=False
+            Whether the interaction modules share parameters.
+        """
+        super().__init__()
+        self.core_module = SpookyNetCore(
+            max_Z=max_Z,
+            number_of_atom_features=number_of_atom_features,
+            number_of_radial_basis_functions=number_of_radial_basis_functions,
+            number_of_interaction_modules=number_of_interaction_modules,
+            number_of_filters=number_of_filters,
+            shared_interactions=shared_interactions,
+        )
+        self.only_unique_pairs = False  # NOTE: for pairlist
+        self.input_preparation = InputPreparation(
+            cutoff=cutoff, only_unique_pairs=self.only_unique_pairs
+        )
+
+    def _config_prior(self):
+        log.info("Configuring SpookyNet model hyperparameter prior distribution")
+        from ray import tune
+
+        from modelforge.potential.utils import shared_config_prior
+
+        prior = {
+            "number_of_atom_features": tune.randint(2, 256),
+            "number_of_interaction_modules": tune.randint(1, 5),
+            "cutoff": tune.uniform(5, 10),
+            "number_of_radial_basis_functions": tune.randint(8, 32),
+            "number_of_filters": tune.randint(32, 128),
+            "shared_interactions": tune.choice([True, False]),
+        }
+        prior.update(shared_config_prior())
+        return prior
diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py
new file mode 100644
index 00000000..fae16370
--- /dev/null
+++ b/modelforge/tests/test_spookynet.py
@@ -0,0 +1,50 @@
+from modelforge.potential.spookynet import SpookyNet
+
+import pytest
+
+
+def test_spookynet_init():
+    """Test initialization of the SpookyNet model."""
+
+    spookynet = SpookyNet()
+    assert spookynet is not None, "SpookyNet model should be initialized."
+
+
+from openff.units import unit
+
+
+@pytest.mark.parametrize(
+    "model_parameter",
+    (
+        [64, 50, 20, unit.Quantity(5.0, unit.angstrom), 2],
+        [32, 60, 10, unit.Quantity(7.0, unit.angstrom), 1],
+        [128, 120, 64, unit.Quantity(5.0, unit.angstrom), 3],
+    ),
+)
+def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter):
+    """
+    Test the forward pass of the SpookyNet model.
+ """ + print(f"model_parameter: {model_parameter}") + ( + nr_atom_basis, + max_atomic_number, + number_of_gaussians, + cutoff, + nr_interaction_blocks, + ) = model_parameter + spookynet = SpookyNet( + number_of_atom_features=nr_atom_basis, + max_Z=max_atomic_number, + number_of_radial_basis_functions=number_of_gaussians, + cutoff=cutoff, + number_of_interaction_modules=nr_interaction_blocks, + ) + energy = spookynet(single_batch_with_batchsize_64.nnp_input).E + nr_of_mols = single_batch_with_batchsize_64.nnp_input.atomic_subsystem_indices.unique().shape[ + 0 + ] + + assert ( + len(energy) == nr_of_mols + ) # Assuming energy is calculated per sample in the batch From 6536b596d42a6a522fe3493036e77897a7754b23 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 31 May 2024 15:49:14 -0700 Subject: [PATCH 02/78] Add spookynet to test environment and list of models. Rename SAKE input dataclass. --- devtools/conda-envs/test_env.yaml | 1 + modelforge/potential/__init__.py | 1 - modelforge/potential/models.py | 10 ++++--- modelforge/potential/sake.py | 12 ++++---- modelforge/tests/test_spookynet.py | 46 +++++++++++++++++++++++++++--- 5 files changed, 55 insertions(+), 15 deletions(-) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index c004d10e..bc55e7f4 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -46,5 +46,6 @@ dependencies: - flax - torchviz - git+https://github.com/ArnNag/sake.git@nanometer + - git+https://github.com/OUnke/SpookyNet.git - tensorflow - torchviz diff --git a/modelforge/potential/__init__.py b/modelforge/potential/__init__.py index ae27f060..794ec0c9 100644 --- a/modelforge/potential/__init__.py +++ b/modelforge/potential/__init__.py @@ -4,7 +4,6 @@ from .ani import ANI2x from .sake import SAKE from .spookynet import SpookyNet - from .utils import ( CosineCutoff, RadialSymmetryFunction, diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 4bc52604..f1598efb 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -14,8 +14,9 @@ from modelforge.potential.ani import ANI2x, AniNeuralNetworkData from modelforge.potential.painn import PaiNN, PaiNNNeuralNetworkData from modelforge.potential.physnet import PhysNet, PhysNetNeuralNetworkData - from modelforge.potential.sake import SAKE, SAKENeuralNetworkInput + from modelforge.potential.sake import SAKE, SAKENeuralNetworkData from modelforge.potential.schnet import SchNet, SchnetNeuralNetworkData + from modelforge.potential.spookynet import SpookyNet, SpookyNetNeuralNetworkData # Define NamedTuple for the outputs of Pairlist and Neighborlist forward method @@ -317,7 +318,7 @@ class PyTorch2JAXConverter: """ def convert_to_jax_model( - self, nnp_instance: Union["ANI2x", "SchNet", "PaiNN", "PhysNet"] + self, nnp_instance: Union["ANI2x", "SchNet", "PaiNN", "PhysNet", "SAKE", "SpookyNet"] ) -> JAXModel: """ Convert a PyTorch neural network instance to a JAX model. @@ -624,7 +625,7 @@ def _model_specific_input_preparation( "PaiNNNeuralNetworkData", "SchnetNeuralNetworkData", "AniNeuralNetworkData", - "SAKENeuralNetworkInput", + "SAKENeuralNetworkData", ]: """ Prepares model-specific inputs before the forward pass. 
@@ -655,7 +656,8 @@ def _forward( "PaiNNNeuralNetworkData", "SchnetNeuralNetworkData", "AniNeuralNetworkData", - "SAKENeuralNetworkInput", + "SAKENeuralNetworkData", + "SpookyNetNeuralNetworkData", ], ): """ diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index 78d74f56..6aead159 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -18,7 +18,7 @@ @dataclass -class SAKENeuralNetworkInput: +class SAKENeuralNetworkData: """ A dataclass designed to structure the inputs for SAKE neural network potentials, ensuring an efficient and structured representation of atomic systems for energy computation and @@ -48,7 +48,7 @@ class SAKENeuralNetworkInput: Examples -------- - >>> sake_input = SAKENeuralNetworkInput( + >>> sake_input = SAKENeuralNetworkData( ... atomic_numbers=torch.tensor([1, 6, 6, 8]), ... positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]), ... atomic_subsystem_indices=torch.tensor([0, 0, 0, 0]), @@ -122,7 +122,7 @@ def __init__( def _model_specific_input_preparation( self, data: "NNPInput", pairlist_output: "PairListOutputs" - ) -> SAKENeuralNetworkInput: + ) -> SAKENeuralNetworkData: # Perform atomic embedding number_of_atoms = data.atomic_numbers.shape[0] @@ -133,7 +133,7 @@ def _model_specific_input_preparation( ) ) - nnp_input = SAKENeuralNetworkInput( + nnp_input = SAKENeuralNetworkData( pair_indices=pairlist_output.pair_indices, number_of_atoms=number_of_atoms, positions=data.positions.to(self.embedding.weight.dtype), @@ -144,13 +144,13 @@ def _model_specific_input_preparation( return nnp_input - def _forward(self, data: SAKENeuralNetworkInput): + def _forward(self, data: SAKENeuralNetworkData): """ Compute atomic representations/embeddings. Parameters ---------- - data: SAKENeuralNetworkInput + data: SAKENeuralNetworkData Dataclass containing atomic properties, embeddings, and pairlist. 
Returns diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index fae16370..9bcf9906 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -1,4 +1,6 @@ from modelforge.potential.spookynet import SpookyNet +from spookynet import SpookyNet as RefSpookyNet +import torch import pytest @@ -16,9 +18,9 @@ def test_spookynet_init(): @pytest.mark.parametrize( "model_parameter", ( - [64, 50, 20, unit.Quantity(5.0, unit.angstrom), 2], - [32, 60, 10, unit.Quantity(7.0, unit.angstrom), 1], - [128, 120, 64, unit.Quantity(5.0, unit.angstrom), 3], + [64, 50, 20, unit.Quantity(5.0, unit.angstrom), 2], + [32, 60, 10, unit.Quantity(7.0, unit.angstrom), 1], + [128, 120, 64, unit.Quantity(5.0, unit.angstrom), 3], ), ) def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter): @@ -46,5 +48,41 @@ def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter): ] assert ( - len(energy) == nr_of_mols + len(energy) == nr_of_mols ) # Assuming energy is calculated per sample in the batch + + +def make_random_pairlist(nr_atoms, nr_pairs, include_self_pairs): + if include_self_pairs: + nr_pairs_choose = nr_pairs - nr_atoms + assert nr_pairs_choose >= 0, """Number of pairs must be greater than or equal to the number of atoms if " + include_self_pairs is True.""" + + else: + nr_pairs_choose = nr_pairs + + all_pairs = torch.cartesian_prod(torch.arange(nr_atoms), torch.arange(nr_atoms)) + self_pairs = all_pairs.T[0] == all_pairs.T[1] + non_self_pairs = all_pairs[~self_pairs] + perm = torch.randperm(non_self_pairs.size(0)) + idx = perm[:nr_pairs_choose] + pairlist = non_self_pairs[idx] + if include_self_pairs: + pairlist = torch.cat( + [pairlist, all_pairs[self_pairs]], dim=0 + ) + + return pairlist.T + + +def test_atomic_properties_static(): + ref_spookynet = RefSpookyNet() + + nr_atoms = 5 + geometry_basis = 3 + nr_pairs = 7 + idx_i, idx_j = make_random_pairlist(nr_atoms, nr_pairs, False) + + Z = torch.randint(1, 100, (nr_atoms,)) + R = torch.rand((nr_atoms, geometry_basis)) + print(ref_spookynet._atomic_properties_static(Z, R, idx_i, idx_j)) From a6e8ce54193ac80eaf128e97d54b713832777851 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 4 Jun 2024 17:02:53 -0700 Subject: [PATCH 03/78] Copy SpookyNet interaction module code from reference implementation. Modifications: hardcoded SiLU activation function. Hardcoded approximate attention. Hardcoded initialization of weights of second linear layer in residual block to 0. 
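
For reference, the hardcoded activation keeps the reference module's
`alpha * F.silu(beta * x)` form. A minimal sketch (here `alpha` and `beta`
stand for the learnable per-feature parameters, initialized to 1.0 and
1.702):

    import torch
    import torch.nn.functional as F

    def swish(x: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        # alpha * F.silu(beta * x) == alpha * (beta * x) * torch.sigmoid(beta * x);
        # with beta = 1.702 the sigmoid factor matches the GELU approximation
        # of Hendrycks & Gimpel.
        return alpha * F.silu(beta * x)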
--- modelforge/potential/spookynet.py | 814 +++++++++++++++++++++++------ modelforge/tests/test_spookynet.py | 32 ++ 2 files changed, 673 insertions(+), 173 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index b426fe34..cf48cf6a 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -1,8 +1,11 @@ +import math from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional, Tuple +import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from loguru import logger as log from openff.units import unit @@ -86,14 +89,14 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): class SpookyNetCore(CoreNetwork): def __init__( - self, - max_Z: int = 100, - number_of_atom_features: int = 64, - number_of_radial_basis_functions: int = 20, - number_of_interaction_modules: int = 3, - number_of_filters: int = 64, - shared_interactions: bool = False, - cutoff: unit.Quantity = 5.0 * unit.angstrom, + self, + max_Z: int = 100, + number_of_atom_features: int = 64, + number_of_radial_basis_functions: int = 20, + number_of_interaction_modules: int = 3, + number_of_filters: int = 64, + shared_interactions: bool = False, + cutoff: unit.Quantity = 5.0 * unit.angstrom, ) -> None: """ Initialize the SpookyNet class. @@ -127,10 +130,6 @@ def __init__( self.readout_module = FromAtomToMoleculeReduction() - # Initialize representation block - self.spookynet_representation_module = SpookyNetRepresentation( - cutoff, number_of_radial_basis_functions - ) # Intialize interaction blocks self.interaction_modules = nn.ModuleList( [ @@ -157,7 +156,7 @@ def __init__( ) def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" + self, data: "NNPInput", pairlist_output: "PairListOutputs" ) -> SpookyNetNeuralNetworkData: number_of_atoms = data.atomic_numbers.shape[0] @@ -215,170 +214,19 @@ def _forward(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: } -from torch_scatter import scatter_add - - -class SpookyNetInteractionModule(nn.Module): - def __init__( - self, - number_of_atom_features: int, - number_of_filters: int, - number_of_radial_basis_functions: int, - ) -> None: - """ - Initialize the SpookyNet interaction block. - - Parameters - ---------- - number_of_atom_features : int - Number of atom ffeatures, defines the dimensionality of the embedding. - number_of_filters : int - Number of filters, defines the dimensionality of the intermediate features. - number_of_radial_basis_functions : int - Number of radial basis functions. - """ - super().__init__() - from .utils import Dense, ShiftedSoftplus - - assert ( - number_of_radial_basis_functions > 4 - ), "Number of radial basis functions must be larger than 10." - assert number_of_filters > 1, "Number of filters must be larger than 1." - assert ( - number_of_atom_features > 10 - ), "Number of atom basis must be larger than 10." 
- - self.number_of_atom_features = number_of_atom_features # Initialize parameters - self.intput_to_feature = Dense( - number_of_atom_features, number_of_filters, bias=False, activation=None - ) - self.feature_to_output = nn.Sequential( - Dense( - number_of_filters, number_of_atom_features, activation=ShiftedSoftplus() - ), - Dense(number_of_atom_features, number_of_atom_features, activation=None), - ) - self.filter_network = nn.Sequential( - Dense( - number_of_radial_basis_functions, - number_of_filters, - activation=ShiftedSoftplus(), - ), - Dense(number_of_filters, number_of_filters, activation=None), - ) - - def forward( - self, - x: torch.Tensor, - pairlist: torch.Tensor, # shape [n_pairs, 2] - f_ij: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] - f_ij_cutoff: torch.Tensor, # shape [n_pairs, 1] - ) -> torch.Tensor: - """ - Forward pass for the interaction block. - - Parameters - ---------- - x : torch.Tensor, shape [nr_of_atoms_in_systems, nr_atom_basis] - Input feature tensor for atoms. - pairlist : torch.Tensor, shape [n_pairs, 2] - f_ij : torch.Tensor, shape [n_pairs, 1, number_of_radial_basis_functions] - Radial basis functions for pairs of atoms. - f_ij_cutoff : torch.Tensor, shape [n_pairs, 1] - - Returns - ------- - torch.Tensor, shape [nr_of_atoms_in_systems, nr_atom_basis] - Updated feature tensor after interaction block. - """ - idx_i, idx_j = pairlist[0], pairlist[1] - - # Map input features to the filter space - x = self.intput_to_feature(x) - - # Generate interaction filters based on radial basis functions - W_ij = self.filter_network(f_ij.squeeze(1)) - W_ij = W_ij * f_ij_cutoff - - # Perform continuous-filter convolution - x_j = x[idx_j] - x_ij = x_j * W_ij - x = scatter_add(x_ij, idx_i, dim=0, dim_size=x.size(0)) - - return self.feature_to_output(x) - - -class SpookyNetRepresentation(nn.Module): - def __init__( - self, - radial_cutoff: unit.Quantity, - number_of_radial_basis_functions: int, - ): - """ - Initialize the SpookyNet representation layer. - - Parameters - ---------- - Radial Basis Function Module - """ - super().__init__() - - self.radial_symmetry_function_module = self._setup_radial_symmetry_functions( - radial_cutoff, number_of_radial_basis_functions - ) - # cutoff - from modelforge.potential import CosineCutoff - - self.cutoff_module = CosineCutoff(radial_cutoff) - - def _setup_radial_symmetry_functions( - self, radial_cutoff: unit.Quantity, number_of_radial_basis_functions: int - ): - from .utils import SchnetRadialSymmetryFunction - - radial_symmetry_function = SchnetRadialSymmetryFunction( - number_of_radial_basis_functions=number_of_radial_basis_functions, - max_distance=radial_cutoff, - dtype=torch.float32, - ) - return radial_symmetry_function - - def forward(self, d_ij: torch.Tensor) -> Dict[str, torch.Tensor]: - """ - Generate the radial symmetry representation of the pairwise distances. 
- - Parameters - ---------- - d_ij : Pairwise distances between atoms; shape [n_pairs, 1] - - Returns - ------- - Radial basis functions for pairs of atoms; shape [n_pairs, 1, number_of_radial_basis_functions] - """ - - # Convert distances to radial basis functions - f_ij = self.radial_symmetry_function_module( - d_ij - ) # shape (n_pairs, 1, number_of_radial_basis_functions) - - f_cutoff = self.cutoff_module(d_ij) # shape (n_pairs, 1) - - return {"f_ij": f_ij, "f_cutoff": f_cutoff} - - from .models import InputPreparation, NNPInput, BaseNetwork class SpookyNet(BaseNetwork): def __init__( - self, - max_Z: int = 101, - number_of_atom_features: int = 32, - number_of_radial_basis_functions: int = 20, - number_of_interaction_modules: int = 3, - cutoff: unit.Quantity = 5 * unit.angstrom, - number_of_filters: int = 32, - shared_interactions: bool = False, + self, + max_Z: int = 101, + number_of_atom_features: int = 32, + number_of_radial_basis_functions: int = 20, + number_of_interaction_modules: int = 3, + cutoff: unit.Quantity = 5 * unit.angstrom, + number_of_filters: int = 32, + shared_interactions: bool = False, ) -> None: """ Initialize the SpookyNet network. @@ -427,3 +275,623 @@ def _config_prior(self): } prior.update(shared_config_prior()) return prior + + +class Swish(nn.Module): + """ + Swish activation function with learnable feature-wise parameters: + f(x) = alpha*x * sigmoid(beta*x) + sigmoid(x) = 1/(1 + exp(-x)) + For beta -> 0 : f(x) -> 0.5*alpha*x + For beta -> inf: f(x) -> max(0, alpha*x) + + Arguments: + num_features (int): + Dimensions of feature space. + initial_alpha (float): + Initial "scale" alpha of the "linear component". + initial_beta (float): + Initial "temperature" of the "sigmoid component". The default value + of 1.702 has the effect of initializing swish to an approximation + of the Gaussian Error Linear Unit (GELU) activation function from + Hendrycks, Dan, and Gimpel, Kevin. "Gaussian error linear units + (GELUs)." + """ + + def __init__( + self, num_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.702 + ) -> None: + """ Initializes the Swish class. """ + super(Swish, self).__init__() + self.initial_alpha = initial_alpha + self.initial_beta = initial_beta + self.register_parameter("alpha", nn.Parameter(torch.Tensor(num_features))) + self.register_parameter("beta", nn.Parameter(torch.Tensor(num_features))) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ Initialize parameters alpha and beta. """ + nn.init.constant_(self.alpha, self.initial_alpha) + nn.init.constant_(self.beta, self.initial_beta) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Evaluate activation function given the input features x. + num_features: Dimensions of feature space. + + Arguments: + x (FloatTensor [:, num_features]): + Input features. + + Returns: + y (FloatTensor [:, num_features]): + Activated features. + """ + return self.alpha * F.silu(self.beta * x) + + +class SpookyNetResidual(nn.Module): + """ + Pre-activation residual block inspired by He, Kaiming, et al. "Identity + mappings in deep residual networks.". + + Arguments: + num_features (int): + Dimensions of feature space. + """ + + def __init__( + self, + num_features: int, + bias: bool = True, + ) -> None: + """ Initializes the Residual class. 
""" + super(SpookyNetResidual, self).__init__() + # initialize attributes + self.activation1 = Swish(num_features) + self.linear1 = nn.Linear(num_features, num_features, bias=bias) + self.activation2 = Swish(num_features) + self.linear2 = nn.Linear(num_features, num_features, bias=bias) + self.reset_parameters(bias) + + def reset_parameters(self, bias: bool = True) -> None: + """ Initialize parameters to compute an identity mapping. """ + nn.init.orthogonal_(self.linear1.weight) + nn.init.zeros_(self.linear2.weight) + if bias: + nn.init.zeros_(self.linear1.bias) + nn.init.zeros_(self.linear2.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply residual block to input atomic features. + N: Number of atoms. + num_features: Dimensions of feature space. + + Arguments: + x (FloatTensor [N, num_features]): + Input feature representations of atoms. + + Returns: + y (FloatTensor [N, num_features]): + Output feature representations of atoms. + """ + y = self.activation1(x) + y = self.linear1(y) + y = self.activation2(y) + y = self.linear2(y) + return x + y + + +class SpookyNetResidualStack(nn.Module): + """ + Stack of num_blocks pre-activation residual blocks evaluated in sequence. + + Arguments: + num_features (int): + Dimensions of feature space. + num_residual (int): + Number of residual blocks to be stacked in sequence. + """ + + def __init__( + self, + num_features: int, + num_residual: int, + bias: bool = True, + ) -> None: + """ Initializes the ResidualStack class. """ + super(SpookyNetResidualStack, self).__init__() + self.stack = nn.ModuleList( + [ + SpookyNetResidual(num_features, bias) + for i in range(num_residual) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies all residual blocks to input features in sequence. + N: Number of inputs. + num_features: Dimensions of feature space. + + Arguments: + x (FloatTensor [N, num_features]): + Input feature representations. + + Returns: + y (FloatTensor [N, num_features]): + Output feature representations. + """ + for residual in self.stack: + x = residual(x) + return x + + +class SpookyNetResidualMLP(nn.Module): + def __init__( + self, + num_features: int, + num_residual: int, + bias: bool = True, + ) -> None: + super(SpookyNetResidualMLP, self).__init__() + self.residual = SpookyNetResidualStack( + num_features, num_residual, bias=bias + ) + self.activation = Swish(num_features) + self.linear = nn.Linear(num_features, num_features, bias=bias) + self.reset_parameters(bias) + + def reset_parameters(self, bias: bool = True) -> None: + nn.init.zeros_(self.linear.weight) + if bias: + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.activation(self.residual(x))) + + +class SpookyNetLocalInteraction(nn.Module): + """ + Block for updating atomic features through local interactions with + neighboring atoms (message-passing). + + Arguments: + num_features (int): + Dimensions of feature space. + num_basis_functions (int): + Number of radial basis functions. + num_residual_x (int): + TODO + num_residual_s (int): + TODO + num_residual_p (int): + TODO + num_residual_d (int): + TODO + num_residual (int): + Number of residual blocks to be stacked in sequence. + """ + + def __init__( + self, + num_features: int, + num_basis_functions: int, + num_residual_x: int, + num_residual_s: int, + num_residual_p: int, + num_residual_d: int, + num_residual: int, + ) -> None: + """ Initializes the LocalInteraction class. 
""" + super(SpookyNetLocalInteraction, self).__init__() + self.radial_s = nn.Linear(num_basis_functions, num_features, bias=False) + self.radial_p = nn.Linear(num_basis_functions, num_features, bias=False) + self.radial_d = nn.Linear(num_basis_functions, num_features, bias=False) + self.resblock_x = SpookyNetResidualMLP(num_features, num_residual_x) + self.resblock_s = SpookyNetResidualMLP(num_features, num_residual_s) + self.resblock_p = SpookyNetResidualMLP(num_features, num_residual_p) + self.resblock_d = SpookyNetResidualMLP(num_features, num_residual_d) + self.projection_p = nn.Linear(num_features, 2 * num_features, bias=False) + self.projection_d = nn.Linear(num_features, 2 * num_features, bias=False) + self.resblock = SpookyNetResidualMLP( + num_features, num_residual + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ Initialize parameters. """ + nn.init.orthogonal_(self.radial_s.weight) + nn.init.orthogonal_(self.radial_p.weight) + nn.init.orthogonal_(self.radial_d.weight) + nn.init.orthogonal_(self.projection_p.weight) + nn.init.orthogonal_(self.projection_d.weight) + + def forward( + self, + x: torch.Tensor, + rbf: torch.Tensor, + pij: torch.Tensor, + dij: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + ) -> torch.Tensor: + """ + Evaluate interaction block. + N: Number of atoms. + P: Number of atom pairs. + + x (FloatTensor [N, num_features]): + Atomic feature vectors. + rbf (FloatTensor [N, num_basis_functions]): + Values of the radial basis functions for the pairwise distances. + idx_i (LongTensor [P]): + Index of atom i for all atomic pairs ij. Each pair must be + specified as both ij and ji. + idx_j (LongTensor [P]): + Same as idx_i, but for atom j. + """ + # interaction functions + gs = self.radial_s(rbf) + gp = self.radial_p(rbf).unsqueeze(-2) * pij.unsqueeze(-1) + gd = self.radial_d(rbf).unsqueeze(-2) * dij.unsqueeze(-1) + # atom featurizations + xx = self.resblock_x(x) + xs = self.resblock_s(x) + xp = self.resblock_p(x) + xd = self.resblock_d(x) + # collect neighbors + xs = xs[idx_j] # L=0 + xp = xp[idx_j] # L=1 + xd = xd[idx_j] # L=2 + # sum over neighbors + pp = x.new_zeros(x.shape[0], pij.shape[-1], x.shape[-1]) + dd = x.new_zeros(x.shape[0], dij.shape[-1], x.shape[-1]) + s = xx.index_add(0, idx_i, gs * xs) # L=0 + p = pp.index_add_(0, idx_i, gp * xp.unsqueeze(-2)) # L=1 + d = dd.index_add_(0, idx_i, gd * xd.unsqueeze(-2)) # L=2 + # project tensorial features to scalars + pa, pb = torch.split(self.projection_p(p), p.shape[-1], dim=-1) + da, db = torch.split(self.projection_d(d), d.shape[-1], dim=-1) + return self.resblock(s + (pa * pb).sum(-2) + (da * db).sum(-2)) + + +class SpookyNetAttention(nn.Module): + """ + Efficient (linear scaling) approximation for attention described in + Choromanski, K., et al. "Rethinking Attention with Performers.". + + Arguments: + dim_qk (int): + Dimension of query/key vectors. + num_random_features (int): + Number of random features for approximating attention matrix. If + this is 0, the exact attention matrix is computed. + """ + + def __init__( + self, dim_qk: int, num_random_features: int + ) -> None: + """ Initializes the Attention class. """ + super(SpookyNetAttention, self).__init__() + self.num_random_features = num_random_features + omega = self._omega(num_random_features, dim_qk) + self.register_buffer("omega", torch.tensor(omega, dtype=torch.float32)) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ For compatibility with other modules. 
""" + pass + + def _omega(self, nrows: int, ncols: int) -> np.ndarray: + """ Return a (nrows x ncols) random feature matrix. """ + nblocks = int(nrows / ncols) + blocks = [] + for i in range(nblocks): + block = np.random.normal(size=(ncols, ncols)) + q, _ = np.linalg.qr(block) + blocks.append(np.transpose(q)) + missing_rows = nrows - nblocks * ncols + if missing_rows > 0: + block = np.random.normal(size=(ncols, ncols)) + q, _ = np.linalg.qr(block) + blocks.append(np.transpose(q)[:missing_rows]) + norm = np.linalg.norm( # renormalize rows so they still follow N(0,1) + np.random.normal(size=(nrows, ncols)), axis=1, keepdims=True + ) + return (norm * np.vstack(blocks)).T + + def _phi( + self, + X: torch.Tensor, + is_query: bool, + num_batch: int, + batch_seg: torch.Tensor, + eps: float = 1e-4, + ) -> torch.Tensor: + """ Normalize X and project into random feature space. """ + d = X.shape[-1] + m = self.omega.shape[-1] + U = torch.matmul(X / d ** 0.25, self.omega) + h = torch.sum(X ** 2, dim=-1, keepdim=True) / (2 * d ** 0.5) # OLD + # determine maximum (is subtracted to prevent numerical overflow) + if is_query: + maximum, _ = torch.max(U, dim=-1, keepdim=True) + else: + if num_batch > 1: + brow = batch_seg.view(1, -1, 1).expand(num_batch, -1, U.shape[-1]) + bcol = ( + torch.arange( + num_batch, dtype=batch_seg.dtype, device=batch_seg.device + ) + .view(-1, 1, 1) + .expand(-1, U.shape[-2], U.shape[-1]) + ) + mask = torch.where( + torch.eq(brow, bcol), torch.ones_like(U), torch.zeros_like(U) + ) + tmp = U.unsqueeze(0).expand(num_batch, -1, -1) + tmp, _ = torch.max(mask * tmp, dim=-1) + tmp, _ = torch.max(tmp, dim=-1) + maximum = tmp[batch_seg].unsqueeze(-1) + else: + maximum = torch.max(U) + return (torch.exp(U - h - maximum) + eps) / math.sqrt(m) + + def forward( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + num_batch: int, + batch_seg: torch.Tensor, + mask: Optional[torch.Tensor] = None, + eps: float = 1e-8, + ) -> torch.Tensor: + """ + Compute attention for the given query, key and value vectors. + N: Number of input values. + dim_qk: Dimension of query/key vectors. + dim_v: Dimension of value vectors. + + Arguments: + Q (FloatTensor [N, dim_qk]): + Matrix of N query vectors. + K (FloatTensor [N, dim_qk]): + Matrix of N key vectors. + V (FloatTensor [N, dim_v]): + Matrix of N value vectors. + num_batch (int): + Number of different batches in the input values. + batch_seg (LongTensor [N]): + Index for each input that specifies to which batch it belongs. + For example, when the input consists of a sequence of size 3 and + another sequence of size 5, batch_seg would be + [0, 0, 0, 1, 1, 1, 1, 1] (num_batch would be 2 then). + mask (Optional[FloatTensor [N, N]]): TODO: check shape + Mask to apply to the attention matrix. + eps (float): + Small constant to prevent numerical instability. + Returns: + y (FloatTensor [N, dim_v]): + Attention-weighted sum of value vectors. 
+ """ + Q = self._phi(Q, True, num_batch, batch_seg) # random projection of Q + K = self._phi(K, False, num_batch, batch_seg) # random projection of K + if num_batch > 1: + d = Q.shape[-1] + + # compute norm + idx = batch_seg.unsqueeze(-1).expand(-1, d) + tmp = K.new_zeros(num_batch, d).scatter_add_(0, idx, K) + norm = torch.gather(Q @ tmp.T, -1, batch_seg.unsqueeze(-1)) + eps + + # the ops below are equivalent to this loop (but more efficient): + # return torch.cat([Q[b==batch_seg]@( + # K[b==batch_seg].transpose(-1,-2)@V[b==batch_seg]) + # for b in range(num_batch)])/norm + if mask is None: # mask can be shared across multiple attentions + one_hot = nn.functional.one_hot(batch_seg).to( + dtype=V.dtype, device=V.device + ) + mask = one_hot @ one_hot.transpose(-1, -2) + return ((mask * (K @ Q.transpose(-1, -2))).transpose(-1, -2) @ V) / norm + else: + norm = Q @ torch.sum(K, 0, keepdim=True).T + eps + return (Q @ (K.T @ V)) / norm + + +class SpookyNetNonlocalInteraction(nn.Module): + """ + Block for updating atomic features through nonlocal interactions with all + atoms. + + Arguments: + num_features (int): + Dimensions of feature space. + num_residual_q (int): + Number of residual blocks for queries. + num_residual_k (int): + Number of residual blocks for keys. + num_residual_v (int): + Number of residual blocks for values. + """ + + def __init__( + self, + num_features: int, + num_residual_q: int, + num_residual_k: int, + num_residual_v: int, + ) -> None: + """ Initializes the NonlocalInteraction class. """ + super(SpookyNetNonlocalInteraction, self).__init__() + self.resblock_q = SpookyNetResidualMLP( + num_features, num_residual_q + ) + self.resblock_k = SpookyNetResidualMLP( + num_features, num_residual_k + ) + self.resblock_v = SpookyNetResidualMLP( + num_features, num_residual_v + ) + self.attention = SpookyNetAttention(dim_qk=num_features, num_random_features=num_features) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ For compatibility with other modules. """ + pass + + def forward( + self, + x: torch.Tensor, + num_batch: int, + batch_seg: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Evaluate interaction block. + N: Number of atoms. + + x (FloatTensor [N, num_features]): + Atomic feature vectors. + """ + q = self.resblock_q(x) # queries + k = self.resblock_k(x) # keys + v = self.resblock_v(x) # values + return self.attention(q, k, v, num_batch, batch_seg, mask) + + +class SpookyNetInteractionModule(nn.Module): + """ + InteractionModule of SpookyNet, which computes a single iteration. + + Arguments: + num_features (int): + Dimensions of feature space. + num_basis_functions (int): + Number of radial basis functions. + num_residual_pre (int): + Number of residual blocks applied to atomic features before + interaction with neighbouring atoms. + num_residual_local_x (int): + TODO + num_residual_local_s (int): + TODO + num_residual_local_p (int): + TODO + num_residual_local_d (int): + TODO + num_residual_local (int): + TODO + num_residual_nonlocal_q (int): + Number of residual blocks for queries in nonlocal interactions. + num_residual_nonlocal_k (int): + Number of residual blocks for keys in nonlocal interactions. + num_residual_nonlocal_v (int): + Number of residual blocks for values in nonlocal interactions. + num_residual_post (int): + Number of residual blocks applied to atomic features after + interaction with neighbouring atoms. 
+ num_residual_output (int): + Number of residual blocks applied to atomic features in output + branch. + """ + + def __init__( + self, + num_features: int, + num_basis_functions: int, + num_residual_pre: int, + num_residual_local_x: int, + num_residual_local_s: int, + num_residual_local_p: int, + num_residual_local_d: int, + num_residual_local: int, + num_residual_nonlocal_q: int, + num_residual_nonlocal_k: int, + num_residual_nonlocal_v: int, + num_residual_post: int, + num_residual_output: int, + ) -> None: + """ Initializes the InteractionModule class. """ + super(SpookyNetInteractionModule, self).__init__() + # initialize modules + self.local_interaction = SpookyNetLocalInteraction( + num_features=num_features, + num_basis_functions=num_basis_functions, + num_residual_x=num_residual_local_x, + num_residual_s=num_residual_local_s, + num_residual_p=num_residual_local_p, + num_residual_d=num_residual_local_d, + num_residual=num_residual_local, + ) + self.nonlocal_interaction = SpookyNetNonlocalInteraction( + num_features=num_features, + num_residual_q=num_residual_nonlocal_q, + num_residual_k=num_residual_nonlocal_k, + num_residual_v=num_residual_nonlocal_v, + ) + + self.residual_pre = SpookyNetResidualStack(num_features, num_residual_pre) + self.residual_post = SpookyNetResidualStack(num_features, num_residual_post) + self.resblock = SpookyNetResidualMLP(num_features, num_residual_output) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ For compatibility with other modules. """ + pass + + def forward( + self, + x: torch.Tensor, + rbf: torch.Tensor, + pij: torch.Tensor, + dij: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, + num_batch: int, + batch_seg: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Evaluate all modules in the block. + N: Number of atoms. + P: Number of atom pairs. + B: Batch size (number of different molecules). + + Arguments: + x (FloatTensor [N, num_features]): + Latent atomic feature vectors. + rbf (FloatTensor [P, num_basis_functions]): + Values of the radial basis functions for the pairwise distances. + pij (FloatTensor [P, 3]): + Unit vectors pointing from atom i to atom j for all atomic pairs. + dij (FloatTensor [P]): + Distances between atom i and atom j for all atomic pairs. + idx_i (LongTensor [P]): + Index of atom i for all atomic pairs ij. Each pair must be + specified as both ij and ji. + idx_j (LongTensor [P]): + Same as idx_i, but for atom j. + num_batch (int): + Batch size (number of different molecules). + batch_seg (LongTensor [N]): + Index for each atom that specifies to which molecule in the + batch it belongs. + mask (Optional[FloatTensor [B, N, N]]): TODO: check shape + Mask for attention mechanism to prevent interactions between + atoms of different molecules. + Returns: + x (FloatTensor [N, num_features]): + Updated latent atomic feature vectors. + y (FloatTensor [N, num_features]): + Contribution to output atomic features (environment + descriptors). 
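+
+        Examples:
+            A shape-level sketch with random inputs; every pair appears in
+            both directions, as required above, and all sizes are
+            placeholders:
+
+            >>> module = SpookyNetInteractionModule(
+            ...     num_features=8, num_basis_functions=4, num_residual_pre=1,
+            ...     num_residual_local_x=1, num_residual_local_s=1,
+            ...     num_residual_local_p=1, num_residual_local_d=1,
+            ...     num_residual_local=1, num_residual_nonlocal_q=1,
+            ...     num_residual_nonlocal_k=1, num_residual_nonlocal_v=1,
+            ...     num_residual_post=1, num_residual_output=1,
+            ... )
+            >>> x = torch.rand(3, 8)
+            >>> rbf = torch.rand(6, 4)
+            >>> pij = torch.rand(6, 3)
+            >>> dij = torch.rand(6, 1)
+            >>> idx_i = torch.tensor([0, 1, 0, 2, 1, 2])
+            >>> idx_j = torch.tensor([1, 0, 2, 0, 2, 1])
+            >>> batch_seg = torch.zeros(3, dtype=torch.long)
+            >>> x_new, y = module(x, rbf, pij, dij, idx_i, idx_j, 1, batch_seg)
+            >>> x_new.shape == y.shape == x.shape
+            True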
+ """ + x = self.residual_pre(x) + l = self.local_interaction(x, rbf, pij, dij, idx_i, idx_j) + n = self.nonlocal_interaction(x, num_batch, batch_seg, mask) + x = self.residual_post(x + l + n) + return x, self.resblock(x) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 9bcf9906..d7b1b49f 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -86,3 +86,35 @@ def test_atomic_properties_static(): Z = torch.randint(1, 100, (nr_atoms,)) R = torch.rand((nr_atoms, geometry_basis)) print(ref_spookynet._atomic_properties_static(Z, R, idx_i, idx_j)) + + +def test_spookynet_interaction_module_forward(): + from modelforge.potential.spookynet import SpookyNetInteractionModule + N = 5 + P = 19 + num_features = 7 + B = 23 + spookynet_interaction_module = SpookyNetInteractionModule( + num_features=num_features, + num_basis_functions=5, + num_residual_pre=3, + num_residual_local_x=3, + num_residual_local_s=3, + num_residual_local_p=3, + num_residual_local_d=3, + num_residual_local=3, + num_residual_nonlocal_q=11, + num_residual_nonlocal_k=13, + num_residual_nonlocal_v=17, + num_residual_post=3, + num_residual_output=3 + ) + + x = torch.rand((N, num_features)) + rbf = torch.rand((P, 5)) + pij = torch.rand((P, 1)) + dij = torch.rand((P, 1)) + idx_i, idx_j = make_random_pairlist(N, P, include_self_pairs=False) + batch_seg = torch.randint(0, B, (N,)) + mask = torch.rand((B, N, N)) + spookynet_interaction_module(x, rbf, pij, dij, idx_i, idx_j, B, batch_seg, mask) From fe5cd04cf707863f314f34072fdac1b70268543b Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 20 Jun 2024 08:02:00 -0700 Subject: [PATCH 04/78] Hard code one batch (relevant within attention). Rename variables to better match paper. --- modelforge/potential/spookynet.py | 228 +++++++++++++++-------------- modelforge/potential/utils.py | 33 +++++ modelforge/tests/test_spookynet.py | 4 +- 3 files changed, 155 insertions(+), 110 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index cf48cf6a..a9b74581 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -192,18 +192,16 @@ def _forward(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: # Compute the representation for each atom (transform to radial basis set, multiply by cutoff) representation = self.spookynet_representation_module(data.d_ij) - data.f_ij = representation["f_ij"] - data.f_cutoff = representation["f_cutoff"] + data.filters = representation["filters"] x = data.atomic_embedding + + f = x.new_zeros(x.size()) # initialize output features to zero # Iterate over interaction blocks to update features for interaction in self.interaction_modules: - v = interaction( - x, - data.pair_indices, - representation["f_ij"], - representation["f_cutoff"], + x, y = interaction( + x, rbf, pij, dij, sr_idx_i, sr_idx_j, num_batch, batch_seg, mask ) - x = x + v # Update atomic features + f += y # accumulate module output to features E_i = self.energy_layer(x).squeeze(1) @@ -277,6 +275,80 @@ def _config_prior(self): return prior +class SpookyNetRepresentation(nn.Module): + + def __init__( + self, + cutoff: unit = 5 * unit.angstrom, + number_of_radial_basis_functions: int = 16, + ): + """ + Representation module for the PhysNet potential, handling the generation of + the radial basis functions (RBFs) with a cutoff. 
+ + Parameters + ---------- + cutoff : openff.units.unit.Quantity, default=5*unit.angstrom + The cutoff distance for interactions. + number_of_gaussians : int, default=16 + Number of Gaussian functions to use in the radial basis function. + """ + + super().__init__() + + # cutoff + from modelforge.potential import CosineCutoff + + + # radial symmetry function + from .utils import PhysNetRadialSymmetryFunction + + self.radial_symmetry_function_module = PhysNetRadialSymmetryFunction( + number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=cutoff, + dtype=torch.float32, + ) + + def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Tensor]: + """ + Forward pass of the representation module. + + Parameters + ---------- + d_ij : torch.Tensor + pairwise distances between atoms, shape (n_pairs). + r_ij : torch.Tensor + pairwise displacements between atoms, shape (n_pairs, 3). + + Returns + ------- + torch.Tensor + The radial basis function expansion applied to the input distances, + shape (n_pairs, n_gaussians), after applying the cutoff function. + """ + + sqrt3 = math.sqrt(3) + sqrt3half = 0.5 * sqrt3 + # short-range distances + p_orbital_ij = r_ij / d_ij.unsqueeze(-1) + d_orbital_ij = torch.stack( + [ + sqrt3 * p_orbital_ij[:, 0] * p_orbital_ij[:, 1], # xy + sqrt3 * p_orbital_ij[:, 0] * p_orbital_ij[:, 2], # xz + sqrt3 * p_orbital_ij[:, 1] * p_orbital_ij[:, 2], # yz + 0.5 * (3 * p_orbital_ij[:, 2] * p_orbital_ij[:, 2] - 1.0), # z2 + sqrt3half + * (p_orbital_ij[:, 0] * p_orbital_ij[:, 0] - p_orbital_ij[:, 1] * p_orbital_ij[:, 1]), # x2-y2 + ], + dim=-1, + ) + f_ij = self.radial_symmetry_function_module(d_ij) + f_ij_cutoff = self.cutoff_module(d_ij) + filters = f_ij * f_ij_cutoff + + return {"filters": filters, "p_orbital_ij": p_orbital_ij, "d_orbital_ij": d_orbital_ij} + + class Swish(nn.Module): """ Swish activation function with learnable feature-wise parameters: @@ -390,14 +462,14 @@ class SpookyNetResidualStack(nn.Module): Arguments: num_features (int): Dimensions of feature space. - num_residual (int): + number_of_residual_blocks (int): Number of residual blocks to be stacked in sequence. """ def __init__( self, num_features: int, - num_residual: int, + number_of_residual_blocks: int, bias: bool = True, ) -> None: """ Initializes the ResidualStack class. """ @@ -405,7 +477,7 @@ def __init__( self.stack = nn.ModuleList( [ SpookyNetResidual(num_features, bias) - for i in range(num_residual) + for _ in range(number_of_residual_blocks) ] ) @@ -432,12 +504,12 @@ class SpookyNetResidualMLP(nn.Module): def __init__( self, num_features: int, - num_residual: int, + number_of_residual_blocks: int, bias: bool = True, ) -> None: super(SpookyNetResidualMLP, self).__init__() self.residual = SpookyNetResidualStack( - num_features, num_residual, bias=bias + num_features, number_of_residual_blocks, bias=bias ) self.activation = Swish(num_features) self.linear = nn.Linear(num_features, num_features, bias=bias) @@ -510,10 +582,10 @@ def reset_parameters(self) -> None: def forward( self, - x: torch.Tensor, + x_tilde: torch.Tensor, rbf: torch.Tensor, - pij: torch.Tensor, - dij: torch.Tensor, + p_orbital_ij: torch.Tensor, + d_orbital_ij: torch.Tensor, idx_i: torch.Tensor, idx_j: torch.Tensor, ) -> torch.Tensor: @@ -526,6 +598,10 @@ def forward( Atomic feature vectors. rbf (FloatTensor [N, num_basis_functions]): Values of the radial basis functions for the pairwise distances. 
+ p_orbital_ij (TODO): + TODO + d_orbital_ij (TODO): + TODO idx_i (LongTensor [P]): Index of atom i for all atomic pairs ij. Each pair must be specified as both ij and ji. @@ -534,20 +610,20 @@ def forward( """ # interaction functions gs = self.radial_s(rbf) - gp = self.radial_p(rbf).unsqueeze(-2) * pij.unsqueeze(-1) - gd = self.radial_d(rbf).unsqueeze(-2) * dij.unsqueeze(-1) + gp = self.radial_p(rbf).unsqueeze(-2) * p_orbital_ij.unsqueeze(-1) + gd = self.radial_d(rbf).unsqueeze(-2) * d_orbital_ij.unsqueeze(-1) # atom featurizations - xx = self.resblock_x(x) - xs = self.resblock_s(x) - xp = self.resblock_p(x) - xd = self.resblock_d(x) + xx = self.resblock_x(x_tilde) + xs = self.resblock_s(x_tilde) + xp = self.resblock_p(x_tilde) + xd = self.resblock_d(x_tilde) # collect neighbors xs = xs[idx_j] # L=0 xp = xp[idx_j] # L=1 xd = xd[idx_j] # L=2 # sum over neighbors - pp = x.new_zeros(x.shape[0], pij.shape[-1], x.shape[-1]) - dd = x.new_zeros(x.shape[0], dij.shape[-1], x.shape[-1]) + pp = x_tilde.new_zeros(x_tilde.shape[0], p_orbital_ij.shape[-1], x_tilde.shape[-1]) + dd = x_tilde.new_zeros(x_tilde.shape[0], d_orbital_ij.shape[-1], x_tilde.shape[-1]) s = xx.index_add(0, idx_i, gs * xs) # L=0 p = pp.index_add_(0, idx_i, gp * xp.unsqueeze(-2)) # L=1 d = dd.index_add_(0, idx_i, gd * xd.unsqueeze(-2)) # L=2 @@ -606,8 +682,6 @@ def _phi( self, X: torch.Tensor, is_query: bool, - num_batch: int, - batch_seg: torch.Tensor, eps: float = 1e-4, ) -> torch.Tensor: """ Normalize X and project into random feature space. """ @@ -619,24 +693,7 @@ def _phi( if is_query: maximum, _ = torch.max(U, dim=-1, keepdim=True) else: - if num_batch > 1: - brow = batch_seg.view(1, -1, 1).expand(num_batch, -1, U.shape[-1]) - bcol = ( - torch.arange( - num_batch, dtype=batch_seg.dtype, device=batch_seg.device - ) - .view(-1, 1, 1) - .expand(-1, U.shape[-2], U.shape[-1]) - ) - mask = torch.where( - torch.eq(brow, bcol), torch.ones_like(U), torch.zeros_like(U) - ) - tmp = U.unsqueeze(0).expand(num_batch, -1, -1) - tmp, _ = torch.max(mask * tmp, dim=-1) - tmp, _ = torch.max(tmp, dim=-1) - maximum = tmp[batch_seg].unsqueeze(-1) - else: - maximum = torch.max(U) + maximum = torch.max(U) return (torch.exp(U - h - maximum) + eps) / math.sqrt(m) def forward( @@ -644,9 +701,6 @@ def forward( Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, - num_batch: int, - batch_seg: torch.Tensor, - mask: Optional[torch.Tensor] = None, eps: float = 1e-8, ) -> torch.Tensor: """ @@ -662,44 +716,16 @@ def forward( Matrix of N key vectors. V (FloatTensor [N, dim_v]): Matrix of N value vectors. - num_batch (int): - Number of different batches in the input values. - batch_seg (LongTensor [N]): - Index for each input that specifies to which batch it belongs. - For example, when the input consists of a sequence of size 3 and - another sequence of size 5, batch_seg would be - [0, 0, 0, 1, 1, 1, 1, 1] (num_batch would be 2 then). - mask (Optional[FloatTensor [N, N]]): TODO: check shape - Mask to apply to the attention matrix. eps (float): Small constant to prevent numerical instability. Returns: y (FloatTensor [N, dim_v]): Attention-weighted sum of value vectors. 
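+
+        Examples:
+            A shape-level sketch with random inputs (placeholders); the
+            linear-scaling order of operations is Q @ (K.T @ V):
+
+            >>> attention = SpookyNetAttention(dim_qk=8, num_random_features=8)
+            >>> Q = torch.rand(5, 8)
+            >>> K = torch.rand(5, 8)
+            >>> V = torch.rand(5, 8)
+            >>> attention(Q, K, V).shape
+            torch.Size([5, 8])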
""" - Q = self._phi(Q, True, num_batch, batch_seg) # random projection of Q - K = self._phi(K, False, num_batch, batch_seg) # random projection of K - if num_batch > 1: - d = Q.shape[-1] - - # compute norm - idx = batch_seg.unsqueeze(-1).expand(-1, d) - tmp = K.new_zeros(num_batch, d).scatter_add_(0, idx, K) - norm = torch.gather(Q @ tmp.T, -1, batch_seg.unsqueeze(-1)) + eps - - # the ops below are equivalent to this loop (but more efficient): - # return torch.cat([Q[b==batch_seg]@( - # K[b==batch_seg].transpose(-1,-2)@V[b==batch_seg]) - # for b in range(num_batch)])/norm - if mask is None: # mask can be shared across multiple attentions - one_hot = nn.functional.one_hot(batch_seg).to( - dtype=V.dtype, device=V.device - ) - mask = one_hot @ one_hot.transpose(-1, -2) - return ((mask * (K @ Q.transpose(-1, -2))).transpose(-1, -2) @ V) / norm - else: - norm = Q @ torch.sum(K, 0, keepdim=True).T + eps - return (Q @ (K.T @ V)) / norm + Q = self._phi(Q, True) # random projection of Q + K = self._phi(K, False) # random projection of K + norm = Q @ torch.sum(K, 0, keepdim=True).T + eps + return (Q @ (K.T @ V)) / norm class SpookyNetNonlocalInteraction(nn.Module): @@ -745,10 +771,7 @@ def reset_parameters(self) -> None: def forward( self, - x: torch.Tensor, - num_batch: int, - batch_seg: torch.Tensor, - mask: Optional[torch.Tensor] = None, + x_tilde: torch.Tensor, ) -> torch.Tensor: """ Evaluate interaction block. @@ -757,10 +780,10 @@ def forward( x (FloatTensor [N, num_features]): Atomic feature vectors. """ - q = self.resblock_q(x) # queries - k = self.resblock_k(x) # keys - v = self.resblock_v(x) # values - return self.attention(q, k, v, num_batch, batch_seg, mask) + q = self.resblock_q(x_tilde) # queries + k = self.resblock_k(x_tilde) # keys + v = self.resblock_v(x_tilde) # values + return self.attention(q, k, v) class SpookyNetInteractionModule(nn.Module): @@ -847,13 +870,10 @@ def forward( self, x: torch.Tensor, rbf: torch.Tensor, - pij: torch.Tensor, - dij: torch.Tensor, + p_orbital_ij: torch.Tensor, + d_orbital_ij: torch.Tensor, idx_i: torch.Tensor, idx_j: torch.Tensor, - num_batch: int, - batch_seg: torch.Tensor, - mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Evaluate all modules in the block. @@ -866,23 +886,15 @@ def forward( Latent atomic feature vectors. rbf (FloatTensor [P, num_basis_functions]): Values of the radial basis functions for the pairwise distances. - pij (FloatTensor [P, 3]): + p_orbital_ij (FloatTensor [P, 3]): Unit vectors pointing from atom i to atom j for all atomic pairs. - dij (FloatTensor [P]): + d_orbital_ij (FloatTensor [P]): Distances between atom i and atom j for all atomic pairs. idx_i (LongTensor [P]): Index of atom i for all atomic pairs ij. Each pair must be specified as both ij and ji. idx_j (LongTensor [P]): Same as idx_i, but for atom j. - num_batch (int): - Batch size (number of different molecules). - batch_seg (LongTensor [N]): - Index for each atom that specifies to which molecule in the - batch it belongs. - mask (Optional[FloatTensor [B, N, N]]): TODO: check shape - Mask for attention mechanism to prevent interactions between - atoms of different molecules. Returns: x (FloatTensor [N, num_features]): Updated latent atomic feature vectors. @@ -890,8 +902,10 @@ def forward( Contribution to output atomic features (environment descriptors). 
""" - x = self.residual_pre(x) - l = self.local_interaction(x, rbf, pij, dij, idx_i, idx_j) - n = self.nonlocal_interaction(x, num_batch, batch_seg, mask) - x = self.residual_post(x + l + n) - return x, self.resblock(x) + x_tilde = self.residual_pre(x) + del x + l = self.local_interaction(x_tilde, rbf, p_orbital_ij, d_orbital_ij, idx_i, idx_j) + n = self.nonlocal_interaction(x_tilde) + x_updated = self.residual_post(x_tilde + l + n) + del x_tilde + return x_updated, self.resblock(x_updated) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index e7c37b02..b9cc5453 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -274,6 +274,39 @@ def forward(self, d_ij: torch.Tensor): return input_cut +class SpookyNetCutoff(nn.Module): + """ + Implements Eq. 16 from + Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with + electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). + Adapted from https://github.com/OUnke/SpookyNet/blob/d57b1fc02c4f1304a9445b2b9aa55a906818dd1b/spookynet/functional.py#L19 # noqa + """ + def __init__(self, cutoff: unit.Quantity): + """ + + Parameters: + ---------- + cutoff: unit.Quantity + The cutoff distance. + + """ + super().__init__() + cutoff = cutoff.to(unit.nanometer).m + self.register_buffer("cutoff", torch.tensor([cutoff])) + + def forward(self, d_ij: torch.Tensor): + """ + Cutoff function that smoothly goes from f(x) = 1 to f(x) = 0 in the interval + from x = 0 to x = cutoff. For x >= cutoff, f(x) = 0. This function has + infinitely many smooth derivatives. Only positive x should be used as input. + """ + zeros = torch.zeros_like(d_ij) + x_ = torch.where(d_ij < self.cutoff, d_ij, zeros) # prevent nan in backprop + return torch.where( + d_ij < self.cutoff, torch.exp(-(x_ ** 2) / ((self.cutoff - x_) * (self.cutoff + x_))), zeros + ) + + from typing import Dict diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index d7b1b49f..ca9d41cd 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -115,6 +115,4 @@ def test_spookynet_interaction_module_forward(): pij = torch.rand((P, 1)) dij = torch.rand((P, 1)) idx_i, idx_j = make_random_pairlist(N, P, include_self_pairs=False) - batch_seg = torch.randint(0, B, (N,)) - mask = torch.rand((B, N, N)) - spookynet_interaction_module(x, rbf, pij, dij, idx_i, idx_j, B, batch_seg, mask) + spookynet_interaction_module(x, rbf, pij, dij, idx_i, idx_j) From 2747fba58fa0b001c2ccf721a0d42136a00bf5fa Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Mon, 24 Jun 2024 23:39:49 -0700 Subject: [PATCH 05/78] More changes. 
Starting to implement tests --- modelforge/potential/spookynet.py | 36 +++++++------ modelforge/potential/utils.py | 10 ++-- modelforge/tests/test_spookynet.py | 81 ++++++++++++++++++++++++++---- 3 files changed, 97 insertions(+), 30 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index a9b74581..e48b5ca6 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -125,6 +125,9 @@ def __init__( self.embedding_module = Embedding(max_Z, number_of_atom_features) + # initialize representation block + self.spookynet_representation_block = SpookyNetRepresentation(cutoff, number_of_radial_basis_functions) + # initialize the energy readout from .processing import FromAtomToMoleculeReduction @@ -192,14 +195,18 @@ def _forward(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: # Compute the representation for each atom (transform to radial basis set, multiply by cutoff) representation = self.spookynet_representation_module(data.d_ij) - data.filters = representation["filters"] x = data.atomic_embedding f = x.new_zeros(x.size()) # initialize output features to zero # Iterate over interaction blocks to update features for interaction in self.interaction_modules: x, y = interaction( - x, rbf, pij, dij, sr_idx_i, sr_idx_j, num_batch, batch_seg, mask + x, + data.pair_indices, + representation["f_ij"], + representation["f_cutoff"], + representation["p_orbital_ij"], + representation["d_orbital_ij"] ) f += y # accumulate module output to features @@ -290,8 +297,8 @@ def __init__( ---------- cutoff : openff.units.unit.Quantity, default=5*unit.angstrom The cutoff distance for interactions. - number_of_gaussians : int, default=16 - Number of Gaussian functions to use in the radial basis function. + number_of_radial_basis_functions : int, default=16 + Number of radial basis functions """ super().__init__() @@ -583,7 +590,7 @@ def reset_parameters(self) -> None: def forward( self, x_tilde: torch.Tensor, - rbf: torch.Tensor, + f_ij_after_cutoff: torch.Tensor, p_orbital_ij: torch.Tensor, d_orbital_ij: torch.Tensor, idx_i: torch.Tensor, @@ -609,9 +616,9 @@ def forward( Same as idx_i, but for atom j. """ # interaction functions - gs = self.radial_s(rbf) - gp = self.radial_p(rbf).unsqueeze(-2) * p_orbital_ij.unsqueeze(-1) - gd = self.radial_d(rbf).unsqueeze(-2) * d_orbital_ij.unsqueeze(-1) + gs = self.radial_s(f_ij_after_cutoff) + gp = self.radial_p(f_ij_after_cutoff).unsqueeze(-2) * p_orbital_ij.unsqueeze(-1) + gd = self.radial_d(f_ij_after_cutoff).unsqueeze(-2) * d_orbital_ij.unsqueeze(-1) # atom featurizations xx = self.resblock_x(x_tilde) xs = self.resblock_s(x_tilde) @@ -869,11 +876,11 @@ def reset_parameters(self) -> None: def forward( self, x: torch.Tensor, - rbf: torch.Tensor, - p_orbital_ij: torch.Tensor, - d_orbital_ij: torch.Tensor, - idx_i: torch.Tensor, - idx_j: torch.Tensor, + pairlist: torch.Tensor, # shape [n_pairs, 2] + f_ij: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] TODO: why the 1? + f_ij_cutoff: torch.Tensor, # shape [n_pairs, 1] + p_orbital_ij: torch.Tensor, # shape [n_pairs, 1] + d_orbital_ij: torch.Tensor, # shape [n_pairs, 1] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Evaluate all modules in the block. @@ -902,9 +909,10 @@ def forward( Contribution to output atomic features (environment descriptors). 
""" + idx_i, idx_j = pairlist[0], pairlist[1] x_tilde = self.residual_pre(x) del x - l = self.local_interaction(x_tilde, rbf, p_orbital_ij, d_orbital_ij, idx_i, idx_j) + l = self.local_interaction(x_tilde, f_ij * f_ij_cutoff, p_orbital_ij, d_orbital_ij, idx_i, idx_j) n = self.nonlocal_interaction(x_tilde) x_updated = self.residual_post(x_tilde + l + n) del x_tilde diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index b9cc5453..4571dd95 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -296,14 +296,14 @@ def __init__(self, cutoff: unit.Quantity): def forward(self, d_ij: torch.Tensor): """ - Cutoff function that smoothly goes from f(x) = 1 to f(x) = 0 in the interval - from x = 0 to x = cutoff. For x >= cutoff, f(x) = 0. This function has - infinitely many smooth derivatives. Only positive x should be used as input. + Cutoff function that smoothly goes from f(r) = 1 to f(r) = 0 in the interval + from r = 0 to r = cutoff. For r >= cutoff, f(r) = 0. This function has + infinitely many smooth derivatives. Only positive r should be used as input. """ zeros = torch.zeros_like(d_ij) - x_ = torch.where(d_ij < self.cutoff, d_ij, zeros) # prevent nan in backprop + r_ = torch.where(d_ij < self.cutoff, d_ij, zeros) # prevent nan in backprop return torch.where( - d_ij < self.cutoff, torch.exp(-(x_ ** 2) / ((self.cutoff - x_) * (self.cutoff + x_))), zeros + d_ij < self.cutoff, torch.exp(-(r_ ** 2) / ((self.cutoff - r_) * (self.cutoff + r_))), zeros ) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index ca9d41cd..b135e7e0 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -54,13 +54,15 @@ def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter): def make_random_pairlist(nr_atoms, nr_pairs, include_self_pairs): if include_self_pairs: + assert nr_pairs <= nr_atoms ** 2, """Number of pairs requested is more than the number of possible pairs.""" nr_pairs_choose = nr_pairs - nr_atoms - assert nr_pairs_choose >= 0, """Number of pairs must be greater than or equal to the number of atoms if " - include_self_pairs is True.""" else: + assert nr_pairs <= nr_atoms ** 2 - nr_atoms, """Number of pairs requested is more than the number of possible + pairs.""" nr_pairs_choose = nr_pairs - + assert nr_pairs_choose >= 0, """Number of pairs must be greater than or equal to the number of atoms if " + include_self_pairs is True or greater than 0 if include_self_pairs is False.""" all_pairs = torch.cartesian_prod(torch.arange(nr_atoms), torch.arange(nr_atoms)) self_pairs = all_pairs.T[0] == all_pairs.T[1] non_self_pairs = all_pairs[~self_pairs] @@ -93,17 +95,17 @@ def test_spookynet_interaction_module_forward(): N = 5 P = 19 num_features = 7 - B = 23 + number_of_radial_basis_functions = 5 spookynet_interaction_module = SpookyNetInteractionModule( num_features=num_features, - num_basis_functions=5, + num_basis_functions=number_of_radial_basis_functions, num_residual_pre=3, num_residual_local_x=3, num_residual_local_s=3, num_residual_local_p=3, num_residual_local_d=3, num_residual_local=3, - num_residual_nonlocal_q=11, + num_residual_nonlocal_q=19, num_residual_nonlocal_k=13, num_residual_nonlocal_v=17, num_residual_post=3, @@ -111,8 +113,65 @@ def test_spookynet_interaction_module_forward(): ) x = torch.rand((N, num_features)) - rbf = torch.rand((P, 5)) - pij = torch.rand((P, 1)) - dij = torch.rand((P, 1)) - idx_i, idx_j = make_random_pairlist(N, P, 
include_self_pairs=False) - spookynet_interaction_module(x, rbf, pij, dij, idx_i, idx_j) + f_ij = torch.rand((P, number_of_radial_basis_functions)) + f_ij_cutoff = torch.rand((P, 1)) + p_orbital_ij = torch.rand((P, 1)) + d_orbital_ij = torch.rand((P, 1)) + pairlist = make_random_pairlist(N, P, include_self_pairs=False) + spookynet_interaction_module(x, pairlist, f_ij, f_ij_cutoff, p_orbital_ij, d_orbital_ij) + +def test_spookynet_interaction_module_against_reference(): + from modelforge.potential.spookynet import SpookyNetInteractionModule as MfSpookyNetInteractionModule + from spookynet.modules.interaction_module import InteractionModule as RefSpookyNetInteractionModule + N = 5 + P = 19 + num_features = 7 + number_of_radial_basis_functions = 5 + num_residual_all = 3 + mf_spookynet_interaction_module = MfSpookyNetInteractionModule( + num_features=num_features, + num_basis_functions=number_of_radial_basis_functions, + num_residual_pre=num_residual_all, + num_residual_local_x=num_residual_all, + num_residual_local_s=num_residual_all, + num_residual_local_p=num_residual_all, + num_residual_local_d=num_residual_all, + num_residual_local=num_residual_all, + num_residual_nonlocal_q=num_residual_all, + num_residual_nonlocal_k=num_residual_all, + num_residual_nonlocal_v=num_residual_all, + num_residual_post=num_residual_all, + num_residual_output=num_residual_all + ).to(torch.double) + + ref_spookynet_interaction_module = RefSpookyNetInteractionModule( + num_features=num_features, + num_basis_functions=number_of_radial_basis_functions, + num_residual_pre=num_residual_all, + num_residual_local_x=num_residual_all, + num_residual_local_s=num_residual_all, + num_residual_local_p=num_residual_all, + num_residual_local_d=num_residual_all, + num_residual_local=num_residual_all, + num_residual_nonlocal_q=num_residual_all, + num_residual_nonlocal_k=num_residual_all, + num_residual_nonlocal_v=num_residual_all, + num_residual_post=num_residual_all, + num_residual_output=num_residual_all + ).to(torch.double) + + for model in [mf_spookynet_interaction_module, ref_spookynet_interaction_module]: + for name, param in model.named_parameters(): + print(name, param.size()) + + + x = torch.rand((N, num_features), dtype=torch.double) + f_ij = torch.rand((P, number_of_radial_basis_functions), dtype=torch.double) + f_ij_cutoff = torch.rand((P, 1), dtype=torch.double) + p_orbital_ij = torch.rand((P, 1), dtype=torch.double) + d_orbital_ij = torch.rand((P, 1), dtype=torch.double) + pairlist = make_random_pairlist(N, P, include_self_pairs=False) + idx_i, idx_j = pairlist + mf_spookynet_interaction_result = mf_spookynet_interaction_module(x, pairlist, f_ij, f_ij_cutoff, p_orbital_ij, d_orbital_ij) + ref_spookynet_interaction_result = ref_spookynet_interaction_module(x, f_ij * f_ij_cutoff, p_orbital_ij, d_orbital_ij, idx_i, idx_j, 1, None, None) + From 4a5e072f51258e3d754bd393c3a0968e4584065e Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 25 Jun 2024 22:49:19 -0700 Subject: [PATCH 06/78] Implement equivalence test for SpookyNet interaction block. First element of tuple return value matches, but second element (after final resblock) does not. 
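
One way to localize such a mismatch is to record every submodule output with
forward hooks and compare the two runs pairwise. The helper below is a minimal
sketch, not part of this patch; it assumes the two modules have structurally
identical submodules so that comparing them in registration order is meaningful.

import torch

def record_submodule_outputs(module):
    # Attach a forward hook to every submodule so that one forward pass
    # leaves behind a {qualified_name: output} record for comparison.
    records, hooks = {}, []
    for name, submodule in module.named_modules():
        def hook(mod, inputs, output, name=name):
            records[name] = output
        hooks.append(submodule.register_forward_hook(hook))
    return records, hooks

# After copying parameters across, run both modules on the same inputs and
# scan the records (tensors only) for the first divergence, e.g.:
# for (mf_name, mf_out), (ref_name, ref_out) in zip(
#     mf_records.items(), ref_records.items()
# ):
#     if isinstance(mf_out, torch.Tensor) and not torch.allclose(mf_out, ref_out):
#         print(f"first divergence: {mf_name} vs {ref_name}")
#         break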
--- .gitignore | 71 +++++++++++++- modelforge/potential/spookynet.py | 1 - modelforge/tests/test_spookynet.py | 27 ++++-- notebooks/benchmark.ipynb | 143 +++++++++++++++++++++++------ 4 files changed, 203 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index 8a753aee..af000ea0 100644 --- a/.gitignore +++ b/.gitignore @@ -189,4 +189,73 @@ lightning_logs/ *.gz *.npz *.hdf5 -scripts/tb_logs/* \ No newline at end of file +scripts/tb_logs/* +/.idea/.gitignore +/modelforge/tests/dataset_test/data_test/ani1x_dataset_v0_nc_1000_processed.json +/modelforge/tests/ani1x_dataset_v0_nc_1000_processed.json +/modelforge/tests/dataset_test/data_test/ani2x_dataset_v0_nc_1000_processed.json +/modelforge/tests/ani2x_dataset_v0_nc_1000_processed.json +/modelforge/tests/ase.toml +/modelforge/tests/dataset_statistics.pt +/modelforge/tests/tb_logs/training/version_0/checkpoints/epoch=1-step=4.ckpt +/modelforge/tests/tb_logs/training/version_1/checkpoints/epoch=1-step=4.ckpt +/modelforge/tests/tb_logs/training/version_10/checkpoints/epoch=1-step=4.ckpt +/modelforge/tests/tb_logs/training/version_11/checkpoints/epoch=1-step=4.ckpt +/modelforge/tests/tb_logs/training/version_0/events.out.tfevents.1717786428.Arnav-HP-Pavilion.28563.0 +/modelforge/tests/tb_logs/training/version_1/events.out.tfevents.1717786434.Arnav-HP-Pavilion.28563.1 +/modelforge/tests/tb_logs/training/version_2/events.out.tfevents.1717786438.Arnav-HP-Pavilion.28563.2 +/modelforge/tests/tb_logs/training/version_3/events.out.tfevents.1717786440.Arnav-HP-Pavilion.28563.3 +/modelforge/tests/tb_logs/training/version_4/events.out.tfevents.1717786442.Arnav-HP-Pavilion.28563.4 +/modelforge/tests/tb_logs/training/version_5/events.out.tfevents.1717786444.Arnav-HP-Pavilion.28563.5 +/modelforge/tests/tb_logs/training/version_6/events.out.tfevents.1717786445.Arnav-HP-Pavilion.28563.6 +/modelforge/tests/tb_logs/training/version_7/events.out.tfevents.1717786447.Arnav-HP-Pavilion.28563.7 +/modelforge/tests/tb_logs/training/version_8/events.out.tfevents.1717786449.Arnav-HP-Pavilion.28563.8 +/modelforge/tests/tb_logs/training/version_9/events.out.tfevents.1717786450.Arnav-HP-Pavilion.28563.9 +/modelforge/tests/tb_logs/training/version_10/events.out.tfevents.1717796413.Arnav-HP-Pavilion.40083.0 +/modelforge/tests/tb_logs/training/version_11/events.out.tfevents.1717796420.Arnav-HP-Pavilion.40083.1 +/modelforge/tests/tb_logs/training/version_12/events.out.tfevents.1717796424.Arnav-HP-Pavilion.40083.2 +/modelforge/tests/tb_logs/training/version_13/events.out.tfevents.1717796427.Arnav-HP-Pavilion.40083.3 +/modelforge/tests/tb_logs/training/version_14/events.out.tfevents.1717796429.Arnav-HP-Pavilion.40083.4 +/modelforge/tests/tb_logs/training/version_15/events.out.tfevents.1717796432.Arnav-HP-Pavilion.40083.5 +/modelforge/tests/tb_logs/training/version_16/events.out.tfevents.1717796434.Arnav-HP-Pavilion.40083.6 +/modelforge/tests/tb_logs/training/version_17/events.out.tfevents.1717796435.Arnav-HP-Pavilion.40083.7 +/modelforge/tests/tb_logs/training/version_18/events.out.tfevents.1717796437.Arnav-HP-Pavilion.40083.8 +/modelforge/tests/tb_logs/training/version_19/events.out.tfevents.1717796439.Arnav-HP-Pavilion.40083.9 +/modelforge/tests/tb_logs/training/version_0/hparams.yaml +/modelforge/tests/tb_logs/training/version_1/hparams.yaml +/modelforge/tests/tb_logs/training/version_2/hparams.yaml +/modelforge/tests/tb_logs/training/version_3/hparams.yaml +/modelforge/tests/tb_logs/training/version_4/hparams.yaml 
+/modelforge/tests/tb_logs/training/version_6/hparams.yaml +/modelforge/tests/tb_logs/training/version_7/hparams.yaml +/modelforge/tests/tb_logs/training/version_8/hparams.yaml +/modelforge/tests/tb_logs/training/version_9/hparams.yaml +/modelforge/tests/tb_logs/training/version_10/hparams.yaml +/modelforge/tests/tb_logs/training/version_11/hparams.yaml +/modelforge/tests/tb_logs/training/version_12/hparams.yaml +/modelforge/tests/tb_logs/training/version_13/hparams.yaml +/modelforge/tests/tb_logs/training/version_14/hparams.yaml +/modelforge/tests/tb_logs/training/version_16/hparams.yaml +/modelforge/tests/tb_logs/training/version_18/hparams.yaml +/modelforge/tests/tb_logs/training/version_19/hparams.yaml +/.idea/misc.xml +/modelforge/tests/model.pth +/.idea/modelforge.iml +/.idea/modules.xml +/modelforge/tests/dataset_test/data_test/PhAlkEthOH_dataset_v0_nc_1000_processed.json +/modelforge/tests/PhAlkEthOH_dataset_v0_nc_1000_processed.json +/.idea/inspectionProfiles/profiles_settings.xml +/modelforge/tests/dataset_test/data_test/qm9_dataset_v0_nc_1000_processed.json +/modelforge/tests/dataset_test/test_diff_scenario/qm9_dataset_v0_nc_1000_processed.json +/modelforge/tests/qm9_dataset_v0_nc_1000_processed.json +/modelforge/tests/dataset_test/data_test/qm9_test.json +/modelforge/tests/dataset_test/data_test/SPICE2_dataset_v0_nc_1000_processed.json +/modelforge/tests/SPICE2_dataset_v0_nc_1000_processed.json +/modelforge/tests/dataset_test/data_test/SPICE114_dataset_nc_1000_processed.json +/modelforge/tests/SPICE114_dataset_nc_1000_processed.json +/modelforge/tests/dataset_test/data_test/SPICE114OpenFF_dataset_nc_1000_processed.json +/modelforge/tests/SPICE114OpenFF_dataset_nc_1000_processed.json +/modelforge/tests/test.chp +/modelforge/tests/torch_dataset.pt +/.idea/vcs.xml +.gitignore diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index e48b5ca6..0f084b99 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -95,7 +95,6 @@ def __init__( number_of_radial_basis_functions: int = 20, number_of_interaction_modules: int = 3, number_of_filters: int = 64, - shared_interactions: bool = False, cutoff: unit.Quantity = 5.0 * unit.angstrom, ) -> None: """ diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index b135e7e0..19370ec5 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -120,6 +120,7 @@ def test_spookynet_interaction_module_forward(): pairlist = make_random_pairlist(N, P, include_self_pairs=False) spookynet_interaction_module(x, pairlist, f_ij, f_ij_cutoff, p_orbital_ij, d_orbital_ij) + def test_spookynet_interaction_module_against_reference(): from modelforge.potential.spookynet import SpookyNetInteractionModule as MfSpookyNetInteractionModule from spookynet.modules.interaction_module import InteractionModule as RefSpookyNetInteractionModule @@ -144,7 +145,7 @@ def test_spookynet_interaction_module_against_reference(): num_residual_output=num_residual_all ).to(torch.double) - ref_spookynet_interaction_module = RefSpookyNetInteractionModule( + ref_spookynet_interaction_module = RefSpookyNetInteractionModule( num_features=num_features, num_basis_functions=number_of_radial_basis_functions, num_residual_pre=num_residual_all, @@ -160,9 +161,18 @@ def test_spookynet_interaction_module_against_reference(): num_residual_output=num_residual_all ).to(torch.double) - for model in [mf_spookynet_interaction_module, ref_spookynet_interaction_module]: - for 
name, param in model.named_parameters(): - print(name, param.size()) + for (_, mf_param), (_, ref_param) in zip(mf_spookynet_interaction_module.named_parameters(), + ref_spookynet_interaction_module.named_parameters()): + mf_param.requires_grad = False + mf_param[:] = ref_param + + assert len(list(mf_spookynet_interaction_module.resblock.named_parameters())) == len(list(ref_spookynet_interaction_module.resblock.named_parameters())) + for (mf_name, mf_param), (ref_name, ref_param) in zip(mf_spookynet_interaction_module.resblock.named_parameters(), ref_spookynet_interaction_module.resblock.named_parameters()): + print(f"{mf_name=} {ref_name=}") + if not torch.equal(mf_param, ref_param): + print(f"{mf_param=} {ref_param=}") + else: + print("parameters are the same") x = torch.rand((N, num_features), dtype=torch.double) @@ -172,6 +182,9 @@ def test_spookynet_interaction_module_against_reference(): d_orbital_ij = torch.rand((P, 1), dtype=torch.double) pairlist = make_random_pairlist(N, P, include_self_pairs=False) idx_i, idx_j = pairlist - mf_spookynet_interaction_result = mf_spookynet_interaction_module(x, pairlist, f_ij, f_ij_cutoff, p_orbital_ij, d_orbital_ij) - ref_spookynet_interaction_result = ref_spookynet_interaction_module(x, f_ij * f_ij_cutoff, p_orbital_ij, d_orbital_ij, idx_i, idx_j, 1, None, None) - + mf_x_result, mf_y_result = mf_spookynet_interaction_module(x, pairlist, f_ij, f_ij_cutoff, p_orbital_ij, + d_orbital_ij) + ref_x_result, ref_y_result = ref_spookynet_interaction_module(x, f_ij * f_ij_cutoff, p_orbital_ij, + d_orbital_ij, idx_i, idx_j, 1, None, None) + assert torch.equal(mf_x_result, ref_x_result) + assert torch.equal(mf_y_result, ref_y_result) diff --git a/notebooks/benchmark.ipynb b/notebooks/benchmark.ipynb index cc9b9b1c..b3a9e22c 100644 --- a/notebooks/benchmark.ipynb +++ b/notebooks/benchmark.ipynb @@ -2,22 +2,32 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-12T21:11:54.333656Z", + "start_time": "2024-06-12T21:11:54.330838Z" + } + }, "source": [ "# Following this guid to benchmark PyTorch operations: https://pytorch.org/tutorials/recipes/recipes/benchmark.html#benchmarking-with-torch-utils-benchmark-timer\n", "\n", "import torch\n", "import torch.utils.benchmark as benchmark\n" - ] + ], + "outputs": [], + "execution_count": 65 }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-12T21:11:56.253655Z", + "start_time": "2024-06-12T21:11:56.247676Z" + } + }, "source": [ + "num_selections = 10\n", + "\n", "# define the functions to compare/benchmark/time\n", "def index_using_gather(tensor, indices):\n", " \"\"\"Selects elements from a tensor using gather (for N, 1).\"\"\"\n", @@ -27,35 +37,51 @@ " \"\"\"Selects elements from a tensor using integer indexing (for N, 1).\"\"\"\n", " return tensor[indices] # Direct indexing on the first dimension\n", "\n", + "def index_using_index_select(tensor, indices):\n", + " return torch.index_select(tensor, 0, indices)\n", + "\n", "# Sample tensor and indices\n", "tensor = torch.randn(1000, 1)\n", - "indices = torch.randint(0, tensor.shape[0], (100, )) # Generate random indices for N\n" - ] + "indices = torch.randint(0, tensor.shape[0], (num_selections, )) # Generate random indices for N\n" + ], + "outputs": [], + "execution_count": 66 }, { + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-12T21:11:58.788683Z", + "start_time": 
"2024-06-12T21:11:58.774244Z" + } + }, "cell_type": "code", - "execution_count": 4, - "metadata": {}, + "source": [ + "tensor = tensor.to(\"cuda\")\n", + "indices = indices.to(\"cuda\")" + ], "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Gather:\n", - "\n", - "index_using_gather(tensor.clone(), indices.clone())\n", - "setup: from __main__ import index_using_gather, tensor, indices\n", - " 6.72 us\n", - " 1 measurement, 1000 runs , 1 thread\n", - "Integer Indexing:\n", - "\n", - "index_using_integral_indexing(tensor.clone(), indices.clone())\n", - "setup: from __main__ import index_using_integral_indexing, tensor, indices\n", - " 6.89 us\n", - " 1 measurement, 1000 runs , 1 thread\n" + "ename": "RuntimeError", + "evalue": "CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mRuntimeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[67], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m tensor \u001B[38;5;241m=\u001B[39m \u001B[43mtensor\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mto\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mcuda\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 2\u001B[0m indices \u001B[38;5;241m=\u001B[39m indices\u001B[38;5;241m.\u001B[39mto(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcuda\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n", + "\u001B[0;31mRuntimeError\u001B[0m: CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" ] } ], + "execution_count": 67 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-12T21:12:01.456047Z", + "start_time": "2024-06-12T21:12:01.373567Z" + } + }, "source": [ "# Benchmarking with pytorch.utils.benchmark\n", "t_gather = benchmark.Timer(\n", @@ -66,16 +92,73 @@ " stmt=\"index_using_integral_indexing(tensor.clone(), indices.clone())\",\n", " setup=\"from __main__ import index_using_integral_indexing, tensor, indices\",\n", ")\n", + "t_index_select = benchmark.Timer(\n", + " stmt=\"index_using_index_select(tensor.clone(), indices.clone())\",\n", + " setup=\"from __main__ import index_using_index_select, tensor, indices\",\n", + ")\n", "\n", "# Repeatedly run the timers for more accurate measurements\n", "print(\"Gather:\")\n", - "print(t_gather.timeit(number=1000)) # Run 1000 times for better accuracy\n", + "print(t_gather.timeit(number=400000)) # Run many times for better accuracy\n", "print(\"Integer Indexing:\")\n", - "print(t_indexing.timeit(number=1000))\n", + "print(t_indexing.timeit(number=400000))\n", + "print(\"Index Select:\")\n", + "print(t_index_select.timeit(number=400000))\n", "\n", "# Ensure outputs are the same\n", - "assert torch.allclose(index_using_gather(tensor.clone(), indices.clone()), index_using_integral_indexing(tensor.clone(), indices.clone()))\n" - ] + "assert torch.allclose(index_using_gather(tensor.clone(), indices.clone()), 
index_using_integral_indexing(tensor.clone(), indices.clone()))\n", + "assert torch.allclose(index_using_gather(tensor.clone(), indices.clone()), index_using_index_select(tensor.clone(), indices.clone()))\n" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Gather:\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mRuntimeError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[68], line 17\u001B[0m\n\u001B[1;32m 15\u001B[0m \u001B[38;5;66;03m# Repeatedly run the timers for more accurate measurements\u001B[39;00m\n\u001B[1;32m 16\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mGather:\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m---> 17\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[43mt_gather\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtimeit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnumber\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m400000\u001B[39;49m\u001B[43m)\u001B[49m) \u001B[38;5;66;03m# Run many times for better accuracy\u001B[39;00m\n\u001B[1;32m 18\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mInteger Indexing:\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m 19\u001B[0m \u001B[38;5;28mprint\u001B[39m(t_indexing\u001B[38;5;241m.\u001B[39mtimeit(number\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m400000\u001B[39m))\n", + "File \u001B[0;32m~/miniforge3/envs/test/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py:274\u001B[0m, in \u001B[0;36mTimer.timeit\u001B[0;34m(self, number)\u001B[0m\n\u001B[1;32m 267\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Mirrors the semantics of timeit.Timer.timeit().\u001B[39;00m\n\u001B[1;32m 268\u001B[0m \n\u001B[1;32m 269\u001B[0m \u001B[38;5;124;03mExecute the main statement (`stmt`) `number` times.\u001B[39;00m\n\u001B[1;32m 270\u001B[0m \u001B[38;5;124;03mhttps://docs.python.org/3/library/timeit.html#timeit.Timer.timeit\u001B[39;00m\n\u001B[1;32m 271\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 272\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m common\u001B[38;5;241m.\u001B[39mset_torch_threads(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_task_spec\u001B[38;5;241m.\u001B[39mnum_threads):\n\u001B[1;32m 273\u001B[0m \u001B[38;5;66;03m# Warmup\u001B[39;00m\n\u001B[0;32m--> 274\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_timeit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnumber\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43mmax\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mnumber\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m/\u001B[39;49m\u001B[38;5;241;43m/\u001B[39;49m\u001B[43m \u001B[49m\u001B[38;5;241;43m100\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m2\u001B[39;49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 276\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m common\u001B[38;5;241m.\u001B[39mMeasurement(\n\u001B[1;32m 
277\u001B[0m number_per_run\u001B[38;5;241m=\u001B[39mnumber,\n\u001B[1;32m 278\u001B[0m raw_times\u001B[38;5;241m=\u001B[39m[\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_timeit(number\u001B[38;5;241m=\u001B[39mnumber)],\n\u001B[1;32m 279\u001B[0m task_spec\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_task_spec\n\u001B[1;32m 280\u001B[0m )\n", + "File \u001B[0;32m~/miniforge3/envs/test/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py:264\u001B[0m, in \u001B[0;36mTimer._timeit\u001B[0;34m(self, number)\u001B[0m\n\u001B[1;32m 261\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_timeit\u001B[39m(\u001B[38;5;28mself\u001B[39m, number: \u001B[38;5;28mint\u001B[39m) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28mfloat\u001B[39m:\n\u001B[1;32m 262\u001B[0m \u001B[38;5;66;03m# Even calling a timer in C++ takes ~50 ns, so no real operation should\u001B[39;00m\n\u001B[1;32m 263\u001B[0m \u001B[38;5;66;03m# take less than 1 ns. (And this prevents divide by zero errors.)\u001B[39;00m\n\u001B[0;32m--> 264\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mmax\u001B[39m(\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_timer\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtimeit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mnumber\u001B[49m\u001B[43m)\u001B[49m, \u001B[38;5;241m1e-9\u001B[39m)\n", + "File \u001B[0;32m~/miniforge3/envs/test/lib/python3.10/timeit.py:178\u001B[0m, in \u001B[0;36mTimer.timeit\u001B[0;34m(self, number)\u001B[0m\n\u001B[1;32m 176\u001B[0m gc\u001B[38;5;241m.\u001B[39mdisable()\n\u001B[1;32m 177\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m--> 178\u001B[0m timing \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43minner\u001B[49m\u001B[43m(\u001B[49m\u001B[43mit\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mtimer\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 179\u001B[0m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[1;32m 180\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m gcold:\n", + "File \u001B[0;32m:4\u001B[0m, in \u001B[0;36minner\u001B[0;34m(_it, _timer)\u001B[0m\n", + "File \u001B[0;32m~/miniforge3/envs/test/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py:18\u001B[0m, in \u001B[0;36mtimer\u001B[0;34m()\u001B[0m\n\u001B[1;32m 17\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mtimer\u001B[39m() \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m \u001B[38;5;28mfloat\u001B[39m:\n\u001B[0;32m---> 18\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcuda\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msynchronize\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 19\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m timeit\u001B[38;5;241m.\u001B[39mdefault_timer()\n", + "File \u001B[0;32m~/miniforge3/envs/test/lib/python3.10/site-packages/torch/cuda/__init__.py:792\u001B[0m, in \u001B[0;36msynchronize\u001B[0;34m(device)\u001B[0m\n\u001B[1;32m 790\u001B[0m _lazy_init()\n\u001B[1;32m 791\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mcuda\u001B[38;5;241m.\u001B[39mdevice(device):\n\u001B[0;32m--> 792\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m 
\u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_C\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_cuda_synchronize\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", + "\u001B[0;31mRuntimeError\u001B[0m: CUDA error: unspecified launch failure\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" + ] + } + ], + "execution_count": 68 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-12T19:50:15.263575Z", + "start_time": "2024-06-12T19:50:15.259846Z" + } + }, + "cell_type": "code", + "source": "tensor.device", + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda', index=0)" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 62 } ], "metadata": { From 16538b62053842882bc3d03f19fce56c5b55b8fa Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Wed, 26 Jun 2024 08:10:46 -0700 Subject: [PATCH 07/78] Implement SpookyNet radial basis function. Test failing (probably due to unit conversion). --- modelforge/potential/utils.py | 86 ++++++++++++++++++++++++++++++ modelforge/tests/test_spookynet.py | 24 +++++++-- 2 files changed, 107 insertions(+), 3 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 4571dd95..f4e1cc6f 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -307,6 +307,82 @@ def forward(self, d_ij: torch.Tensor): ) +class ExponentialBernsteinPolynomials(nn.Module): + """ + Taken from SpookyNet. + Radial basis functions based on exponential Bernstein polynomials given by: + b_{v,n}(x) = (n over v) * exp(-alpha*x)**v * (1-exp(-alpha*x))**(n-v) + (see https://en.wikipedia.org/wiki/Bernstein_polynomial) + Here, n = num_basis_functions-1 and v takes values from 0 to n. This + implementation operates in log space to prevent multiplication of very large + (n over v) and very small numbers (exp(-alpha*x)**v and + (1-exp(-alpha*x))**(n-v)) for numerical stability. + NOTE: There is a problem for x = 0, as log(-expm1(0)) will be log(0) = -inf. + This itself is not an issue, but the buffer v contains an entry 0 and + 0*(-inf)=nan. The correct behaviour could be recovered by replacing the nan + with 0.0, but should not be necessary because issues are only present when + r = 0, which will not occur with chemically meaningful inputs. + + Arguments: + num_basis_functions (int): + Number of radial basis functions. + x = infinity. + ini_alpha (float): + Initial value for scaling parameter alpha (Default value corresponds + to 0.5 1/Bohr converted to 1/Angstrom). + """ + + def __init__( + self, + num_basis_functions: int, + ini_alpha: Quantity = 0.5 / unit.bohr, + dtype: Optional[torch.dtype] = None, + ) -> None: + """ Initializes the ExponentialBernsteinPolynomials class. 
""" + super(ExponentialBernsteinPolynomials, self).__init__() + self.ini_alpha = ini_alpha.to(1 / unit.nanometer).m + # compute values to initialize buffers + logfactorial = np.zeros(num_basis_functions) + for i in range(2, num_basis_functions): + logfactorial[i] = logfactorial[i - 1] + np.log(i) + v = np.arange(0, num_basis_functions) + n = (num_basis_functions - 1) - v + logbinomial = logfactorial[-1] - logfactorial[v] - logfactorial[n] + # register buffers and parameters + self.register_buffer("logc", torch.tensor(logbinomial, dtype=dtype)) + self.register_buffer("n", torch.tensor(n, dtype=dtype)) + self.register_buffer("v", torch.tensor(v, dtype=dtype)) + self.register_parameter( + "_alpha", nn.Parameter(torch.tensor(1.0, dtype=dtype)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ Initialize exponential scaling parameter alpha. """ + nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha)) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + """ + Evaluates radial basis functions given distances + N: Number of input values. + num_basis_functions: Number of radial basis functions. + + Arguments: + r (FloatTensor [N]): + Input distances. + + Returns: + rbf (FloatTensor [N, num_basis_functions]): + Values of the radial basis functions for the distances r. + """ + alphar = -F.softplus(self._alpha) * r.view(-1, 1) + x = self.logc + self.n * alphar + self.v * torch.log(-torch.expm1(alphar)) + print(f"{self.logc.shape=}") + + rbf = torch.exp(x) + return rbf * torch.exp(alphar) + + from typing import Dict @@ -336,6 +412,16 @@ def forward(self, x: torch.Tensor): return functional.softplus(x) - self.log_2 +def softplus_inverse(x): + """ + From SpookyNet: + Inverse of the softplus function. This is useful for initialization of + parameters that are constrained to be positive (via softplus). 
+ """ + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + return x + torch.log(-torch.expm1(-x)) + class AngularSymmetryFunction(nn.Module): """ diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 19370ec5..69d93614 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -166,15 +166,16 @@ def test_spookynet_interaction_module_against_reference(): mf_param.requires_grad = False mf_param[:] = ref_param - assert len(list(mf_spookynet_interaction_module.resblock.named_parameters())) == len(list(ref_spookynet_interaction_module.resblock.named_parameters())) - for (mf_name, mf_param), (ref_name, ref_param) in zip(mf_spookynet_interaction_module.resblock.named_parameters(), ref_spookynet_interaction_module.resblock.named_parameters()): + assert len(list(mf_spookynet_interaction_module.resblock.named_parameters())) == len( + list(ref_spookynet_interaction_module.resblock.named_parameters())) + for (mf_name, mf_param), (ref_name, ref_param) in zip(mf_spookynet_interaction_module.resblock.named_parameters(), + ref_spookynet_interaction_module.resblock.named_parameters()): print(f"{mf_name=} {ref_name=}") if not torch.equal(mf_param, ref_param): print(f"{mf_param=} {ref_param=}") else: print("parameters are the same") - x = torch.rand((N, num_features), dtype=torch.double) f_ij = torch.rand((P, number_of_radial_basis_functions), dtype=torch.double) f_ij_cutoff = torch.rand((P, 1), dtype=torch.double) @@ -188,3 +189,20 @@ def test_spookynet_interaction_module_against_reference(): d_orbital_ij, idx_i, idx_j, 1, None, None) assert torch.equal(mf_x_result, ref_x_result) assert torch.equal(mf_y_result, ref_y_result) + + +def test_spookynet_bernstein_polynomial_equivalence(): + from spookynet.modules.exponential_bernstein_polynomials import ExponentialBernsteinPolynomials as RefExponentialBernsteinPolynomials + from modelforge.potential.utils import ExponentialBernsteinPolynomials as MfExponentialBernSteinPolynomials + + num_basis_functions = 3 + ref_exp_bernstein_polynomials = RefExponentialBernsteinPolynomials(num_basis_functions, exp_weighting=True) + mf_exp_bernstein_polynomials = MfExponentialBernSteinPolynomials(num_basis_functions) + + N = 5 + r_angstrom = torch.rand((N, 1)) + r_nanometer = r_angstrom * 0.1 + cutoff_values = torch.rand((N, 1)) + ref_exp_bernstein_polynomial_result = ref_exp_bernstein_polynomials(r_angstrom, cutoff_values) + mf_exp_bernstein_polynomial_result = mf_exp_bernstein_polynomials(r_nanometer) * cutoff_values + assert torch.equal(ref_exp_bernstein_polynomial_result, mf_exp_bernstein_polynomial_result) From 9a210b80f12a52ed9bfab148ebc9c310cda16784 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Wed, 26 Jun 2024 22:49:07 -0700 Subject: [PATCH 08/78] Begin refactoring radial basis functions --- modelforge/potential/__init__.py | 2 +- modelforge/potential/sake.py | 4 +- modelforge/potential/utils.py | 317 +++++++++++++++++-------------- modelforge/tests/test_sake.py | 4 +- modelforge/tests/test_utils.py | 6 +- 5 files changed, 179 insertions(+), 154 deletions(-) diff --git a/modelforge/potential/__init__.py b/modelforge/potential/__init__.py index c3f3a6f0..9a558fc2 100644 --- a/modelforge/potential/__init__.py +++ b/modelforge/potential/__init__.py @@ -5,7 +5,7 @@ from .sake import SAKE from .utils import ( CosineCutoff, - RadialSymmetryFunction, + RadialBasisFunction, AngularSymmetryFunction, ) from .processing import FromAtomToMoleculeReduction diff --git a/modelforge/potential/sake.py 
b/modelforge/potential/sake.py index 921b1be2..fc5ca994 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -10,7 +10,7 @@ Dense, scatter_softmax, SAKERadialSymmetryFunction, - SAKERadialBasisFunction, + SAKERadialBasisFunctionCore, ) from modelforge.dataset.dataset import NNPInput import torch @@ -239,7 +239,7 @@ def __init__( max_distance=cutoff, dtype=torch.float32, trainable=False, - radial_basis_function=SAKERadialBasisFunction(0.0 * unit.nanometer), + radial_basis_function=SAKERadialBasisFunctionCore(0.0 * unit.nanometer), ) self.node_mlp = nn.Sequential( diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index e7c37b02..81fdc82a 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -10,6 +10,7 @@ from typing import Union from modelforge.dataset.dataset import NNPInput + @dataclass class NeuralNetworkData: pair_indices: torch.Tensor @@ -22,7 +23,6 @@ class NeuralNetworkData: total_charge: torch.Tensor - import torch @@ -42,7 +42,7 @@ class Metadata: F: torch.Tensor = torch.tensor([], dtype=torch.float32) def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): """Move all tensors in this instance to the specified device.""" if device: @@ -64,9 +64,9 @@ class BatchData: metadata: Metadata def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): self.nnp_input = self.nnp_input.to(device=device, dtype=dtype) self.metadata = self.metadata.to(device=device, dtype=dtype) @@ -85,7 +85,7 @@ def shared_config_prior(): def triple_by_molecule( - atom_pairs: torch.Tensor, + atom_pairs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input: indices for pairs of atoms that are close to each other. each pair only appear once, i.e. only one of the pairs (1, 2) and @@ -121,8 +121,8 @@ def cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor: torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) ) mask = ( - torch.arange(intra_pair_indices.shape[2], device=ai1.device) - < pair_sizes.unsqueeze(1) + torch.arange(intra_pair_indices.shape[2], device=ai1.device) + < pair_sizes.unsqueeze(1) ).flatten() sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) @@ -199,13 +199,13 @@ class Dense(nn.Linear): """ def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - activation: Optional[nn.Module] = None, - weight_init: Callable = xavier_uniform_, - bias_init: Callable = zeros_, + self, + in_features: int, + out_features: int, + bias: bool = True, + activation: Optional[nn.Module] = None, + weight_init: Callable = xavier_uniform_, + bias_init: Callable = zeros_, ): """ Args: @@ -267,7 +267,7 @@ def forward(self, d_ij: torch.Tensor): """ # Compute values of cutoff function input_cut = 0.5 * ( - torch.cos(d_ij * np.pi / self.cutoff) + 1.0 + torch.cos(d_ij * np.pi / self.cutoff) + 1.0 ) # NOTE: ANI adds 0.5 instead of 1. 
# Remove contributions beyond the cutoff radius input_cut *= (d_ij < self.cutoff).float() @@ -311,13 +311,13 @@ class AngularSymmetryFunction(nn.Module): """ def __init__( - self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, - number_of_gaussians_for_asf: int = 8, - angle_sections: int = 4, - trainable: bool = False, - dtype: Optional[torch.dtype] = None, + self, + max_distance: unit.Quantity, + min_distance: unit.Quantity, + number_of_gaussians_for_asf: int = 8, + angle_sections: int = 4, + trainable: bool = False, + dtype: Optional[torch.dtype] = None, ) -> None: """ Parameters @@ -424,43 +424,84 @@ def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: from abc import ABC, abstractmethod -class RadialBasisFunction(ABC): +class RadialBasisFunctionCore(ABC): + @staticmethod @abstractmethod - def compute(self, distances, centers, scale_factors): + def compute(nondimensionalized_distances): + """ + Parameters + --------- + nondimensionalized_distances: torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + Nondimensional quantities that range from 0 to infinity. All distances should be nondimensionalized + appropriately before passing as input to this function. + + Returns + --------- + torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + """ pass -class GaussianRadialBasisFunction(RadialBasisFunction): +class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): + + @staticmethod def compute( - self, - distances: torch.Tensor, - centers: torch.Tensor, - scale_factors: torch.Tensor, + nondimensionalized_distances: torch.Tensor, ) -> torch.Tensor: - diff = distances - centers - return torch.exp((-1 * scale_factors) * diff**2) + return torch.exp(-nondimensionalized_distances ** 2) + +class DoubleExponentialRadialBasisFunctionCore(RadialBasisFunctionCore): -class DoubleExponentialRadialBasisFunction(RadialBasisFunction): + @staticmethod def compute( - self, - distances: torch.Tensor, - centers: torch.Tensor, - scale_factors: torch.Tensor, + nondimensionalized_distances: torch.Tensor, ) -> torch.Tensor: - diff = distances - centers - return torch.exp(-torch.abs(diff / scale_factors)) + return torch.exp(-torch.abs(nondimensionalized_distances)) + +class RadialBasisFunction(nn.Module, ABC): -class RadialSymmetryFunction(nn.Module): + def __init__(self, radial_basis_function: RadialBasisFunctionCore, trainable: bool = False): + super().__init__() + if trainable: + self.prefactor = nn.Parameter(torch.tensor([1.0])) + else: + self.register_buffer("prefactor", torch.tensor([1.0])) + self.radial_basis_function = radial_basis_function + + @abstractmethod + def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + """ + Parameters + --------- + distances: torch.Tensor, shape [number_of_pairs, 1] + Distances between atoms in each pair in nanometers. + + Returns + --------- + torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + Nondimensional quantities that range from 0 to infinity. + """ + pass + + def forward(self, distances: torch.Tensor) -> torch.Tensor: + """ + Applies nondimensionalization transformations on the distances and passes the result to RadialBasisFunctionCore. 
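+
+ Parameters
+ ---------
+ distances: torch.Tensor, shape [number_of_pairs, 1]
+ Distances between atoms in each pair in nanometers.
+
+ Returns
+ ---------
+ torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions]
+ Radial basis function values, scaled by self.prefactor.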
+ """ + nondimensionalized_distances = self.nondimensionalize_distances(distances) + return self.prefactor * self.radial_basis_function.compute(nondimensionalized_distances) + + +class RadialBasisFunctionWithCenters(RadialBasisFunction): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: RadialBasisFunctionCore = GaussianRadialBasisFunctionCore(), ): """RadialSymmetryFunction class. @@ -484,7 +525,7 @@ def __init__( symmetry function output given an input distance matrix. """ - super().__init__() + super().__init__(radial_basis_function, trainable) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.max_distance = max_distance self.min_distance = min_distance @@ -518,18 +559,16 @@ def initialize_parameters(self): if self.trainable: self.radial_basis_centers = radial_basis_centers self.radial_scale_factor = radial_scale_factor - self.prefactor = nn.Parameter(torch.tensor([1.0])) else: self.register_buffer("radial_basis_centers", radial_basis_centers) self.register_buffer("radial_scale_factor", radial_scale_factor) - self.register_buffer("prefactor", torch.tensor([1.0])) def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype, ): # the default approach to calculate radial basis centers # can be overwritten by subclasses @@ -542,10 +581,10 @@ def calculate_radial_basis_centers( return centers def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): # the default approach to calculate radial scale factors (each of them are scaled by the same value) # can be overwritten by subclasses @@ -557,34 +596,21 @@ def calculate_radial_scale_factor( scale_factors = scale_factors * -15_000 return scale_factors - def forward(self, d_ij: torch.Tensor) -> torch.Tensor: - """ - Compute the radial symmetry function values for each distance in d_ij. + def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + diff = distances - self.centers + return diff / self.scale_factors - Parameters - ---------- - d_ij: torch.Tensor - pairwise distances with shape [N, 1] where N is the number of pairs. - Returns: - torch.Tensor, - tensor of radial symmetry function values with shape [N, num_basis_functions]. 
- """ - features = self.radial_basis_function.compute( - d_ij, self.radial_basis_centers, self.radial_scale_factor - ) - return self.prefactor * features - -class SchnetRadialSymmetryFunction(RadialSymmetryFunction): +class SchnetRadialSymmetryFunction(RadialBasisFunctionWithCenters): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: RadialBasisFunctionCore = GaussianRadialBasisFunctionCore(), ): """RadialSymmetryFunction class. @@ -605,10 +631,10 @@ def __init__( self.prefactor = torch.tensor([1.0]) def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): scale_factors = torch.linspace( _min_distance_in_nanometer, @@ -617,23 +643,23 @@ def calculate_radial_scale_factor( ) widths = ( - torch.abs(scale_factors[1] - scale_factors[0]) - * torch.ones_like(scale_factors) + torch.abs(scale_factors[1] - scale_factors[0]) + * torch.ones_like(scale_factors) ).to(self.dtype) scale_factors = 0.5 / torch.square_(widths) return scale_factors -class AniRadialSymmetryFunction(RadialSymmetryFunction): +class AniRadialSymmetryFunction(RadialBasisFunctionWithCenters): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunction = GaussianRadialBasisFunction(), + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: RadialBasisFunctionCore = GaussianRadialBasisFunctionCore(), ): """RadialSymmetryFunction class. 
@@ -654,11 +680,11 @@ def __init__( self.prefactor = torch.tensor([0.25]) def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype, ): centers = torch.linspace( _min_distance_in_nanometer, @@ -669,23 +695,23 @@ def calculate_radial_basis_centers( return centers def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): # ANI uses a predefined scaling factor scale_factors = torch.full((number_of_radial_basis_functions,), (19.7 * 100)) return scale_factors -class SAKERadialSymmetryFunction(RadialSymmetryFunction): +class SAKERadialSymmetryFunction(RadialBasisFunctionWithCenters): def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype, ): # initialize means and betas according to the default values in PhysNet # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 @@ -702,10 +728,10 @@ def calculate_radial_basis_centers( return centers def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, + self, + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, ): start_value = torch.exp( torch.scalar_tensor( @@ -721,26 +747,25 @@ def calculate_radial_scale_factor( return radial_scale_factor -class SAKERadialBasisFunction(RadialBasisFunction): +class SAKERadialBasisFunctionCore(RadialBasisFunctionCore): def __init__(self, min_distance): super().__init__() self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m def compute( - self, - distances: torch.Tensor, - centers: torch.Tensor, - scale_factors: torch.Tensor, + self, + distances: torch.Tensor, + centers: torch.Tensor, + scale_factors: torch.Tensor, ) -> torch.Tensor: - return torch.exp( -scale_factors * ( - torch.exp( - (-distances.unsqueeze(-1) + self._min_distance_in_nanometer) * 10 - ) - - centers + torch.exp( + (-distances.unsqueeze(-1) + self._min_distance_in_nanometer) * 10 + ) + - centers ) ** 2 ) @@ -749,13 +774,13 @@ def compute( class PhysNetRadialSymmetryFunction(SAKERadialSymmetryFunction): def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: Optional[SAKERadialBasisFunction] = None, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable: bool = False, + radial_basis_function: Optional[SAKERadialBasisFunctionCore] = None, ): """RadialSymmetryFunction class. 
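For orientation, the SAKE/PhysNet variant above places Gaussians on exp(-d) rather than on d itself, which concentrates resolution at short range. A standalone sketch of the basis these hunks implement, taking min_distance as zero and using illustrative names; the means and betas follow the PhysNet-style initialization quoted above:

import math
import torch

def expnormal_features(d_ij: torch.Tensor, means: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
    # Gaussians in exp(-d) space; the factor 10 converts nm to the
    # 1/angstrom convention of the PhysNet paper
    z = torch.exp(-d_ij.unsqueeze(-1) * 10.0)
    return torch.exp(-betas * (z - means) ** 2)

n_rbf, d_max_nm = 16, 0.5
start = math.exp(-d_max_nm * 10.0)
means = torch.linspace(start, 1.0, n_rbf)                          # mu_k in [exp(-r_c), 1]
betas = torch.full((n_rbf,), (2.0 / n_rbf * (1.0 - start)) ** -2)  # shared beta_k
print(expnormal_features(torch.rand(7) * d_max_nm, means, betas).shape)  # [7, 16]
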
@@ -766,7 +791,7 @@ def __init__( """ # Create the radial_basis_function if not provided if radial_basis_function is None: - radial_basis_function = SAKERadialBasisFunction(min_distance) + radial_basis_function = SAKERadialBasisFunctionCore(min_distance) super().__init__( number_of_radial_basis_functions, @@ -780,8 +805,8 @@ def __init__( def pair_list( - atomic_subsystem_indices: torch.Tensor, - only_unique_pairs: bool = False, + atomic_subsystem_indices: torch.Tensor, + only_unique_pairs: bool = False, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -822,7 +847,7 @@ def pair_list( # filter pairs to only keep those belonging to the same molecule same_molecule_mask = ( - atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] + atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] ) # Apply mask to get final pair indices @@ -835,9 +860,9 @@ def pair_list( return pair_indices.to(device) def forward( - self, - coordinates: torch.Tensor, # in nanometer - atomic_subsystem_indices: torch.Tensor, + self, + coordinates: torch.Tensor, # in nanometer + atomic_subsystem_indices: torch.Tensor, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -869,11 +894,11 @@ def forward( def scatter_softmax( - src: torch.Tensor, - index: torch.Tensor, - dim: int, - dim_size: Optional[int] = None, - device: Optional[torch.device] = None, + src: torch.Tensor, + index: torch.Tensor, + dim: int, + dim_size: Optional[int] = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Softmax operation over all values in :attr:`src` tensor that share indices @@ -907,7 +932,7 @@ def scatter_softmax( assert dim >= 0, f"dim must be non-negative, got {dim}" assert ( - dim < src.dim() + dim < src.dim() ), f"dim must be less than the number of dimensions of src {src.dim()}, got {dim}" out_shape = [ diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 94e47242..341ca2df 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -241,7 +241,7 @@ def make_equivalent_pairlist_mask(key, nr_atoms, nr_pairs, include_self_pairs): def test_radial_symmetry_function_against_reference(): from modelforge.potential.utils import ( SAKERadialSymmetryFunction, - SAKERadialBasisFunction, + SAKERadialBasisFunctionCore, ) from sake.utils import ExpNormalSmearing as RefExpNormalSmearing @@ -258,7 +258,7 @@ def test_radial_symmetry_function_against_reference(): min_distance=cutoff_lower, dtype=torch.float32, trainable=False, - radial_basis_function=SAKERadialBasisFunction(cutoff_lower), + radial_basis_function=SAKERadialBasisFunctionCore(cutoff_lower), ) ref_radial_basis_module = RefExpNormalSmearing( num_rbf=number_of_radial_basis_functions, diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py index a91be3eb..9b5b4eb9 100644 --- a/modelforge/tests/test_utils.py +++ b/modelforge/tests/test_utils.py @@ -2,7 +2,7 @@ import torch import pytest -from modelforge.potential.utils import CosineCutoff, RadialSymmetryFunction +from modelforge.potential.utils import CosineCutoff, RadialBasisFunction def test_dense_layer(): @@ -116,13 +116,13 @@ def test_radial_symmetry_function_implementation(): """ Test the Radial Symmetry function implementation. 
""" - from modelforge.potential.utils import RadialSymmetryFunction, CosineCutoff + from modelforge.potential.utils import RadialBasisFunction, CosineCutoff import torch from openff.units import unit import numpy as np cutoff_module = CosineCutoff(cutoff=unit.Quantity(5.0, unit.angstrom)) - RSF = RadialSymmetryFunction( + RSF = RadialBasisFunction( number_of_radial_basis_functions=18, max_distance=unit.Quantity(5.0, unit.angstrom), ) From 6e7b10ecc9bf80634ff21750edb71b1aeaeba05e Mon Sep 17 00:00:00 2001 From: Arnav Nagle <43835955+ArnNag@users.noreply.github.com> Date: Thu, 27 Jun 2024 08:20:05 -0700 Subject: [PATCH 09/78] Update RadialBasisFunction comment --- modelforge/potential/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 81fdc82a..0fab9d70 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -432,8 +432,7 @@ def compute(nondimensionalized_distances): Parameters --------- nondimensionalized_distances: torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] - Nondimensional quantities that range from 0 to infinity. All distances should be nondimensionalized - appropriately before passing as input to this function. + Nondimensional quantities that depend on pairwise distances. Returns --------- From b6f1149b402087383946438257dc585e4fb2fb4c Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 27 Jun 2024 10:16:25 -0700 Subject: [PATCH 10/78] Fix references to radial basis centers and scale factors --- modelforge/potential/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 81fdc82a..3efabce3 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -597,8 +597,8 @@ def calculate_radial_scale_factor( return scale_factors def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: - diff = distances - self.centers - return diff / self.scale_factors + diff = distances - self.radial_basis_centers + return diff / self.radial_scale_factor From 7757010174a40b043a049b7ef90ab97ff90d071f Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 27 Jun 2024 16:35:40 -0700 Subject: [PATCH 11/78] Continue refactoring radial basis functions --- modelforge/potential/models.py | 1 + modelforge/potential/painn.py | 4 +- modelforge/potential/physnet.py | 5 +- modelforge/potential/sake.py | 9 +- modelforge/potential/schnet.py | 4 +- modelforge/potential/utils.py | 308 +++++++++++++++++--------------- modelforge/tests/test_nn.py | 4 +- modelforge/tests/test_sake.py | 4 +- modelforge/tests/test_spk.py | 4 +- 9 files changed, 178 insertions(+), 165 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 7a5437b4..3d0edd67 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -731,6 +731,7 @@ def forward(self, data: NNPInput): self.input_preparation._input_checks(data) # prepare the input for the forward pass pairlist_output = self.input_preparation.prepare_inputs(data) + print(f"{type(self)}: {pairlist_output.d_ij.shape=}") return self.core_module(data, pairlist_output) diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index 3ed8806d..d60bae7e 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -234,9 +234,9 @@ def __init__( self.cutoff_module = CosineCutoff(cutoff) # radial symmetry function - from .utils 
import SchnetRadialSymmetryFunction + from .utils import SchnetRadialBasisFunction - self.radial_symmetry_function_module = SchnetRadialSymmetryFunction( + self.radial_symmetry_function_module = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, max_distance=cutoff, dtype=torch.float32, diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index ac99bdaa..7aca970d 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -90,9 +90,9 @@ def __init__( self.cutoff_module = CosineCutoff(cutoff) # radial symmetry function - from .utils import PhysNetRadialSymmetryFunction + from .utils import PhysNetRadialBasisFunction - self.radial_symmetry_function_module = PhysNetRadialSymmetryFunction( + self.radial_symmetry_function_module = PhysNetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, max_distance=cutoff, dtype=torch.float32, @@ -469,6 +469,7 @@ def _model_specific_input_preparation( ) -> PhysNetNeuralNetworkData: # Perform atomic embedding + print(f"Physnet {pairlist_output.d_ij.shape=}") atomic_embedding = self.embedding_module(data.atomic_numbers) # Z_i, ..., Z_N # diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index fc5ca994..db3e0ebd 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -9,8 +9,7 @@ from .utils import ( Dense, scatter_softmax, - SAKERadialSymmetryFunction, - SAKERadialBasisFunctionCore, + PhysNetRadialBasisFunction, ) from modelforge.dataset.dataset import NNPInput import torch @@ -234,12 +233,10 @@ def __init__( self.nr_coefficients = nr_coefficients self.nr_heads = nr_heads self.epsilon = epsilon - self.radial_symmetry_function_module = SAKERadialSymmetryFunction( + self.radial_symmetry_function_module = PhysNetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, max_distance=cutoff, dtype=torch.float32, - trainable=False, - radial_basis_function=SAKERadialBasisFunctionCore(0.0 * unit.nanometer), ) self.node_mlp = nn.Sequential( @@ -324,7 +321,7 @@ def update_edge(self, h_i_by_pair, h_j_by_pair, d_ij): Intermediate edge features. Shape [nr_pairs, nr_edge_basis]. 
""" h_ij_cat = torch.cat([h_i_by_pair, h_j_by_pair], dim=-1) - h_ij_filtered = self.radial_symmetry_function_module(d_ij) * self.edge_mlp_in( + h_ij_filtered = self.radial_symmetry_function_module(d_ij.unsqueeze(-1)) * self.edge_mlp_in( h_ij_cat ) return self.edge_mlp_out( diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index 3e63ebcd..76e5192b 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -334,9 +334,9 @@ def __init__( def _setup_radial_symmetry_functions( self, radial_cutoff: unit.Quantity, number_of_radial_basis_functions: int ): - from .utils import SchnetRadialSymmetryFunction + from .utils import SchnetRadialBasisFunction - radial_symmetry_function = SchnetRadialSymmetryFunction( + radial_symmetry_function = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, max_distance=radial_cutoff, dtype=torch.float32, diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 3efabce3..2cdcd5cd 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -1,5 +1,6 @@ +import math from dataclasses import dataclass, field -from typing import Any, Callable, Optional, Tuple, NamedTuple +from typing import Any, Callable, Optional, Tuple, NamedTuple, Type import numpy as np import torch @@ -427,7 +428,7 @@ def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: class RadialBasisFunctionCore(ABC): @staticmethod @abstractmethod - def compute(nondimensionalized_distances): + def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: """ Parameters --------- @@ -445,29 +446,23 @@ def compute(nondimensionalized_distances): class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): @staticmethod - def compute( - nondimensionalized_distances: torch.Tensor, - ) -> torch.Tensor: - return torch.exp(-nondimensionalized_distances ** 2) - - -class DoubleExponentialRadialBasisFunctionCore(RadialBasisFunctionCore): - - @staticmethod - def compute( - nondimensionalized_distances: torch.Tensor, - ) -> torch.Tensor: - return torch.exp(-torch.abs(nondimensionalized_distances)) + def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + return torch.exp(-(nondimensionalized_distances ** 2)) class RadialBasisFunction(nn.Module, ABC): - def __init__(self, radial_basis_function: RadialBasisFunctionCore, trainable: bool = False): + def __init__( + self, + radial_basis_function: Type[RadialBasisFunctionCore], + dtype, + trainable_prefactor: bool = False, + ): super().__init__() - if trainable: - self.prefactor = nn.Parameter(torch.tensor([1.0])) + if trainable_prefactor: + self.prefactor = nn.Parameter(torch.tensor([1.0], dtype=dtype)) else: - self.register_buffer("prefactor", torch.tensor([1.0])) + self.register_buffer("prefactor", torch.tensor([1.0], dtype=dtype)) self.radial_basis_function = radial_basis_function @abstractmethod @@ -481,27 +476,41 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: Returns --------- torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] - Nondimensional quantities that range from 0 to infinity. + Nondimensional quantities computed from the distances. """ pass def forward(self, distances: torch.Tensor) -> torch.Tensor: """ Applies nondimensionalization transformations on the distances and passes the result to RadialBasisFunctionCore. 
+ + Parameters + --------- + distances: torch.Tensor, shape [number_of_pairs, 1] + Distances between atoms in each pair in nanometers. + + Returns + --------- + torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] + Output of radial basis functions. """ nondimensionalized_distances = self.nondimensionalize_distances(distances) - return self.prefactor * self.radial_basis_function.compute(nondimensionalized_distances) + print(f"{nondimensionalized_distances.size()=}: {type(self)}") + return self.prefactor * self.radial_basis_function.compute( + nondimensionalized_distances + ) class RadialBasisFunctionWithCenters(RadialBasisFunction): def __init__( self, + radial_basis_function: Type[RadialBasisFunctionCore], number_of_radial_basis_functions: int, max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunctionCore = GaussianRadialBasisFunctionCore(), + trainable_prefactor: bool = False, + trainable_centers_and_scale_factors: bool = False, ): """RadialSymmetryFunction class. @@ -517,20 +526,22 @@ def __init__( Minimum distance to consider. dtype: Data type for computations. - trainable: bool, default False - Whether parameters are trainable. + trainable_prefactor: bool, default False + Whether prefactor is trainable + trainable_centers_and_scale_factors: bool, default False + Whether centers and scale factors are trainable. radial_basis_function: RadialBasisFunction, default GaussianRadialBasisFunction() Subclasses must implement the forward() method to compute the actual symmetry function output given an input distance matrix. """ - super().__init__(radial_basis_function, trainable) + super().__init__(radial_basis_function, dtype, trainable_prefactor) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.max_distance = max_distance self.min_distance = min_distance self.dtype = dtype - self.trainable = trainable + self.trainable_centers_and_scale_factors = trainable_centers_and_scale_factors self.radial_basis_function = radial_basis_function self.initialize_parameters() # The length of radial subaev of a single species @@ -539,78 +550,66 @@ def __init__( def initialize_parameters(self): # convert to nanometer _max_distance_in_nanometer = self.max_distance.to(unit.nanometer).m + print(f"{_max_distance_in_nanometer=}") _min_distance_in_nanometer = self.min_distance.to(unit.nanometer).m # calculate radial basis centers radial_basis_centers = self.calculate_radial_basis_centers( - _min_distance_in_nanometer, - _max_distance_in_nanometer, self.number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, self.dtype, ) # calculate scale factors radial_scale_factor = self.calculate_radial_scale_factor( - _min_distance_in_nanometer, - _max_distance_in_nanometer, self.number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + self.dtype ) # either add as parameters or register buffers - if self.trainable: + if self.trainable_centers_and_scale_factors: self.radial_basis_centers = radial_basis_centers self.radial_scale_factor = radial_scale_factor else: self.register_buffer("radial_basis_centers", radial_basis_centers) self.register_buffer("radial_scale_factor", radial_scale_factor) + @staticmethod + @abstractmethod def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, number_of_radial_basis_functions, + 
_max_distance_in_nanometer, + _min_distance_in_nanometer, dtype, ): - # the default approach to calculate radial basis centers - # can be overwritten by subclasses - centers = torch.linspace( - _min_distance_in_nanometer, - _max_distance_in_nanometer, - number_of_radial_basis_functions, - dtype=dtype, - ) - return centers + pass + @staticmethod + @abstractmethod def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): - # the default approach to calculate radial scale factors (each of them are scaled by the same value) - # can be overwritten by subclasses - scale_factors = torch.full( - (number_of_radial_basis_functions,), - (_min_distance_in_nanometer - _max_distance_in_nanometer) - / number_of_radial_basis_functions, - ) - scale_factors = scale_factors * -15_000 - return scale_factors + pass def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: diff = distances - self.radial_basis_centers + print(f"{distances.shape=}, {self.radial_basis_centers.shape=}") return diff / self.radial_scale_factor - -class SchnetRadialSymmetryFunction(RadialBasisFunctionWithCenters): +class SchnetRadialBasisFunction(RadialBasisFunctionWithCenters): def __init__( self, number_of_radial_basis_functions: int, max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunctionCore = GaussianRadialBasisFunctionCore(), + trainable_centers_and_scale_factors: bool = False, ): """RadialSymmetryFunction class. @@ -621,20 +620,38 @@ def __init__( """ super().__init__( + GaussianRadialBasisFunctionCore, number_of_radial_basis_functions, max_distance, min_distance, dtype, - trainable, - radial_basis_function, + trainable_prefactor=False, + trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, ) - self.prefactor = torch.tensor([1.0]) - def calculate_radial_scale_factor( - self, + @staticmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + # the default approach to calculate radial basis centers + # can be overwritten by subclasses + centers = torch.linspace( _min_distance_in_nanometer, _max_distance_in_nanometer, number_of_radial_basis_functions, + dtype=dtype, + ) + return centers + + @staticmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): scale_factors = torch.linspace( _min_distance_in_nanometer, @@ -645,7 +662,7 @@ def calculate_radial_scale_factor( widths = ( torch.abs(scale_factors[1] - scale_factors[0]) * torch.ones_like(scale_factors) - ).to(self.dtype) + ).to(dtype) scale_factors = 0.5 / torch.square_(widths) return scale_factors @@ -654,12 +671,11 @@ def calculate_radial_scale_factor( class AniRadialSymmetryFunction(RadialBasisFunctionWithCenters): def __init__( self, - number_of_radial_basis_functions: int, + number_of_radial_basis_functions, max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: RadialBasisFunctionCore = GaussianRadialBasisFunctionCore(), + trainable_centers_and_scale_factors: bool = False, ): """RadialSymmetryFunction class. 
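The refactor above also standardizes how centers and scale factors are stored on the module. A short sketch (hypothetical Expansion class, not the modelforge one) of what the trainable_centers_and_scale_factors flag changes in practice:

import torch
import torch.nn as nn

class Expansion(nn.Module):
    # nn.Parameter is visible to the optimizer; register_buffer keeps the
    # tensor frozen but still moves it with .to(device) and serializes it
    # in the state_dict -- the same branch initialize_parameters() takes
    def __init__(self, centers: torch.Tensor, trainable: bool = False):
        super().__init__()
        if trainable:
            self.centers = nn.Parameter(centers)
        else:
            self.register_buffer("centers", centers)

frozen = Expansion(torch.linspace(0.0, 0.5, 8), trainable=False)
print(list(frozen.parameters()))         # [] -- nothing for the optimizer
print(list(frozen.state_dict().keys()))  # ['centers'] -- still serialized
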
@@ -670,22 +686,24 @@ def __init__( """ super().__init__( + GaussianRadialBasisFunctionCore, number_of_radial_basis_functions, max_distance, min_distance, dtype, - trainable, - radial_basis_function, + trainable_prefactor=False, + trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, ) - self.prefactor = torch.tensor([0.25]) + self.prefactor = torch.tensor([0.25], dtype=dtype) + @staticmethod def calculate_radial_basis_centers( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, dtype, ): + print(f"{_min_distance_in_nanometer=}, {_max_distance_in_nanometer=}, {number_of_radial_basis_functions=}") centers = torch.linspace( _min_distance_in_nanometer, _max_distance_in_nanometer, @@ -694,27 +712,62 @@ def calculate_radial_basis_centers( )[:-1] return centers + @staticmethod def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # ANI uses a predefined scaling factor scale_factors = torch.full((number_of_radial_basis_functions,), (19.7 * 100)) return scale_factors -class SAKERadialSymmetryFunction(RadialBasisFunctionWithCenters): - def calculate_radial_basis_centers( +class PhysNetRadialBasisFunction(RadialBasisFunction): + + def __init__( self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity = 1.0 * unit.nanometer, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable_centers_and_scale_factors: bool = False, + ): + super().__init__(GaussianRadialBasisFunctionCore, trainable_prefactor=False, dtype=dtype) + _max_distance_in_nanometer = max_distance.to(unit.nanometer).m + _min_distance_in_nanometer = min_distance.to(unit.nanometer).m + radial_basis_centers = self.calculate_radial_basis_centers( + number_of_radial_basis_functions, + _max_distance_in_nanometer, _min_distance_in_nanometer, + dtype, + ) + # calculate scale factors + radial_scale_factor = self.calculate_radial_scale_factor( + number_of_radial_basis_functions, _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype + ) + + if trainable_centers_and_scale_factors: + self.radial_basis_centers = radial_basis_centers + self.radial_scale_factor = radial_scale_factor + else: + self.register_buffer("radial_basis_centers", radial_basis_centers) + self.register_buffer("radial_scale_factor", radial_scale_factor) + + @staticmethod + def calculate_radial_basis_centers( number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, dtype, ): # initialize means and betas according to the default values in PhysNet # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 + # NOTE: Unlike RadialBasisFunctionWithCenters, the centers are unitless. 
start_value = torch.exp( torch.scalar_tensor( @@ -727,81 +780,42 @@ def calculate_radial_basis_centers( ) return centers + @staticmethod def calculate_radial_scale_factor( - self, - _min_distance_in_nanometer, - _max_distance_in_nanometer, number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): - start_value = torch.exp( - torch.scalar_tensor( - (-_max_distance_in_nanometer + _min_distance_in_nanometer) * 10 - ) - ) # NOTE: this is defined in Angstrom - radial_scale_factor = torch.tensor( - torch.full( - (number_of_radial_basis_functions,), - (2 / number_of_radial_basis_functions * (1 - start_value)) ** -2, - ) - ) - return radial_scale_factor - - -class SAKERadialBasisFunctionCore(RadialBasisFunctionCore): - - def __init__(self, min_distance): - super().__init__() - self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m - - def compute( - self, - distances: torch.Tensor, - centers: torch.Tensor, - scale_factors: torch.Tensor, - ) -> torch.Tensor: - return torch.exp( - -scale_factors - * ( - torch.exp( - (-distances.unsqueeze(-1) + self._min_distance_in_nanometer) * 10 + # NOTE: Unlike RadialBasisFunctionWithCenters, the scale factors are unitless. + radial_scale_factor = torch.full( + (number_of_radial_basis_functions,), + number_of_radial_basis_functions + / ( + 2 + * ( + 1 + - math.exp( + 10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer) ) - - centers - ) - ** 2 - ) - - -class PhysNetRadialSymmetryFunction(SAKERadialSymmetryFunction): - - def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable: bool = False, - radial_basis_function: Optional[SAKERadialBasisFunctionCore] = None, - ): - """RadialSymmetryFunction class. - - Initializes and contains the logic for computing radial symmetry functions. + ) + ), + dtype=dtype, + ) # NOTE: radial_square_factor here is the square root of beta in the PhysNet paper + return radial_scale_factor - Parameters - --------- + def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + """ + NOTE: In PhysNet, the input to exp is in Angstroms. In modelforge, distances are in nanometer. Thus, we multiply + 1/Angstrom, which is equivalent to 10/nanometer, to make the input to exp unitless. 
""" - # Create the radial_basis_function if not provided - if radial_basis_function is None: - radial_basis_function = SAKERadialBasisFunctionCore(min_distance) - super().__init__( - number_of_radial_basis_functions, - max_distance, - min_distance, - dtype, - trainable, - radial_basis_function, + print(f"{self.radial_scale_factor.shape=}") + print(f"{distances.shape=}") + print(f"{self.radial_basis_centers.shape=}") + return self.radial_scale_factor * ( + torch.exp(-distances * 10) - self.radial_basis_centers ) - self.prefactor = torch.tensor([1.0]) def pair_list( diff --git a/modelforge/tests/test_nn.py b/modelforge/tests/test_nn.py index 8e88dae6..c0399488 100644 --- a/modelforge/tests/test_nn.py +++ b/modelforge/tests/test_nn.py @@ -1,12 +1,12 @@ def test_radial_symmetry_function(): - from modelforge.potential.utils import SchnetRadialSymmetryFunction, CosineCutoff + from modelforge.potential.utils import SchnetRadialBasisFunction, CosineCutoff import torch from openff.units import unit # set cutoff and radial symmetry function cutoff = CosineCutoff(cutoff=unit.Quantity(5.0, unit.angstrom)) - rbf_expension = SchnetRadialSymmetryFunction( + rbf_expension = SchnetRadialBasisFunction( number_of_radial_basis_functions=18, max_distance=unit.Quantity(5.0, unit.angstrom), ) diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 341ca2df..5ac8a3c3 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -240,7 +240,7 @@ def make_equivalent_pairlist_mask(key, nr_atoms, nr_pairs, include_self_pairs): def test_radial_symmetry_function_against_reference(): from modelforge.potential.utils import ( - SAKERadialSymmetryFunction, + PhysNetRadialBasisFunction, SAKERadialBasisFunctionCore, ) from sake.utils import ExpNormalSmearing as RefExpNormalSmearing @@ -252,7 +252,7 @@ def test_radial_symmetry_function_against_reference(): mf_unit = unit.nanometer ref_unit = unit.nanometer - radial_symmetry_function_module = SAKERadialSymmetryFunction( + radial_symmetry_function_module = PhysNetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, max_distance=cutoff_upper, min_distance=cutoff_lower, diff --git a/modelforge/tests/test_spk.py b/modelforge/tests/test_spk.py index 12c1e937..1a7c6116 100644 --- a/modelforge/tests/test_spk.py +++ b/modelforge/tests/test_spk.py @@ -3,7 +3,7 @@ def test_compare_radial_symmetry_features(): # compare schnetpack RadialSymmetryFunction with modelforge RadialSymmetryFunction - from modelforge.potential.utils import SchnetRadialSymmetryFunction + from modelforge.potential.utils import SchnetRadialBasisFunction from schnetpack.nn import GaussianRBF as schnetpackGaussianRBF from openff.units import unit @@ -16,7 +16,7 @@ def test_compare_radial_symmetry_features(): cutoff=cutoff.to(unit.angstrom).m, start=start.to(unit.angstrom).m, ) - radial_symmetry_function_module = SchnetRadialSymmetryFunction( + radial_symmetry_function_module = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_gaussians, max_distance=cutoff, min_distance=start, From e4af3717a0e69901cf5cb8d8b6ac72826ae66453 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 27 Jun 2024 17:09:29 -0700 Subject: [PATCH 12/78] Fix SchNet RBF --- modelforge/potential/__init__.py | 1 - modelforge/potential/models.py | 1 - modelforge/potential/utils.py | 9 +-------- modelforge/tests/conftest.py | 3 --- 4 files changed, 1 insertion(+), 13 deletions(-) diff --git a/modelforge/potential/__init__.py 
b/modelforge/potential/__init__.py index 9a558fc2..757b708e 100644 --- a/modelforge/potential/__init__.py +++ b/modelforge/potential/__init__.py @@ -9,7 +9,6 @@ AngularSymmetryFunction, ) from .processing import FromAtomToMoleculeReduction -from modelforge.train.training import TrainingAdapter from .models import NeuralNetworkPotentialFactory from enum import Enum diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 66dbc6cf..92779073 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -758,7 +758,6 @@ def forward(self, data: NNPInput): self.input_preparation._input_checks(data) # prepare the input for the forward pass pairlist_output = self.input_preparation.prepare_inputs(data) - print(f"{type(self)}: {pairlist_output.d_ij.shape=}") return self.core_module(data, pairlist_output) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index c49c8ed0..d209f600 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -494,7 +494,6 @@ def forward(self, distances: torch.Tensor) -> torch.Tensor: Output of radial basis functions. """ nondimensionalized_distances = self.nondimensionalize_distances(distances) - print(f"{nondimensionalized_distances.size()=}: {type(self)}") return self.prefactor * self.radial_basis_function.compute( nondimensionalized_distances ) @@ -549,7 +548,6 @@ def __init__( def initialize_parameters(self): # convert to nanometer _max_distance_in_nanometer = self.max_distance.to(unit.nanometer).m - print(f"{_max_distance_in_nanometer=}") _min_distance_in_nanometer = self.min_distance.to(unit.nanometer).m # calculate radial basis centers @@ -597,7 +595,6 @@ def calculate_radial_scale_factor( def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: diff = distances - self.radial_basis_centers - print(f"{distances.shape=}, {self.radial_basis_centers.shape=}") return diff / self.radial_scale_factor @@ -663,7 +660,7 @@ def calculate_radial_scale_factor( * torch.ones_like(scale_factors) ).to(dtype) - scale_factors = 0.5 / torch.square_(widths) + scale_factors = math.sqrt(2) * widths return scale_factors @@ -702,7 +699,6 @@ def calculate_radial_basis_centers( _min_distance_in_nanometer, dtype, ): - print(f"{_min_distance_in_nanometer=}, {_max_distance_in_nanometer=}, {number_of_radial_basis_functions=}") centers = torch.linspace( _min_distance_in_nanometer, _max_distance_in_nanometer, @@ -809,9 +805,6 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: 1/Angstrom, which is equivalent to 10/nanometer, to make the input to exp unitless. 
""" - print(f"{self.radial_scale_factor.shape=}") - print(f"{distances.shape=}") - print(f"{self.radial_basis_centers.shape=}") return self.radial_scale_factor * ( torch.exp(-distances * 10) - self.radial_basis_centers ) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index a543b2ba..2075e901 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -28,9 +28,6 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_data_download) -from modelforge.potential.utils import BatchData - - # datamodule fixture @pytest.fixture def datamodule_factory(): From 6398cf868a6a0c2b3fda3f5c205d719ac0729baf Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 27 Jun 2024 22:02:22 -0700 Subject: [PATCH 13/78] Fix PhysNet and ANI RBF --- modelforge/potential/ani.py | 4 ++-- modelforge/potential/physnet.py | 4 ++-- modelforge/potential/utils.py | 40 +++++++++++++------------------- modelforge/tests/test_ani.py | 8 +++---- modelforge/tests/test_physnet.py | 7 +++--- 5 files changed, 28 insertions(+), 35 deletions(-) diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index fc770ae8..432bd09f 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -140,9 +140,9 @@ def _setup_radial_symmetry_functions( min_distance: unit.Quantity, number_of_radial_basis_functions: int, ): - from .utils import AniRadialSymmetryFunction + from .utils import AniRadialBasisFunction - radial_symmetry_function = AniRadialSymmetryFunction( + radial_symmetry_function = AniRadialBasisFunction( number_of_radial_basis_functions, max_distance, min_distance, diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index 7aca970d..03fb7abf 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -78,8 +78,8 @@ def __init__( ---------- cutoff : openff.units.unit.Quantity, default=5*unit.angstrom The cutoff distance for interactions. - number_of_gaussians : int, default=16 - Number of Gaussian functions to use in the radial basis function. + number_of_radial_basis_functions : int, default=16 + Number of radial basis functions to use. 
""" super().__init__() diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index d209f600..6488d6a7 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -581,6 +581,9 @@ def calculate_radial_basis_centers( _min_distance_in_nanometer, dtype, ): + """ + NOTE: centers have units of nanometers + """ pass @staticmethod @@ -591,6 +594,9 @@ def calculate_radial_scale_factor( _min_distance_in_nanometer, dtype, ): + """ + NOTE: radial scale factors have units of nanometers + """ pass def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: @@ -632,15 +638,12 @@ def calculate_radial_basis_centers( _min_distance_in_nanometer, dtype, ): - # the default approach to calculate radial basis centers - # can be overwritten by subclasses - centers = torch.linspace( + return torch.linspace( _min_distance_in_nanometer, _max_distance_in_nanometer, number_of_radial_basis_functions, dtype=dtype, ) - return centers @staticmethod def calculate_radial_scale_factor( @@ -664,7 +667,7 @@ def calculate_radial_scale_factor( return scale_factors -class AniRadialSymmetryFunction(RadialBasisFunctionWithCenters): +class AniRadialBasisFunction(RadialBasisFunctionWithCenters): def __init__( self, number_of_radial_basis_functions, @@ -715,7 +718,7 @@ def calculate_radial_scale_factor( dtype, ): # ANI uses a predefined scaling factor - scale_factors = torch.full((number_of_radial_basis_functions,), (19.7 * 100)) + scale_factors = torch.full((number_of_radial_basis_functions,), (19.7 * 100) ** -0.5) return scale_factors @@ -724,7 +727,7 @@ class PhysNetRadialBasisFunction(RadialBasisFunction): def __init__( self, number_of_radial_basis_functions: int, - max_distance: unit.Quantity = 1.0 * unit.nanometer, + max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, dtype: Optional[torch.dtype] = None, trainable_centers_and_scale_factors: bool = False, @@ -769,7 +772,7 @@ def calculate_radial_basis_centers( (-_max_distance_in_nanometer + _min_distance_in_nanometer) * 10, dtype=dtype, ) - ) # NOTE: this is defined in Angstrom + ) # NOTE: there is an implicit multiplication by 1/Angstrom within the exp, so we multiply by 10/nanometer. centers = torch.linspace( start_value, 1, number_of_radial_basis_functions, dtype=dtype ) @@ -783,21 +786,12 @@ def calculate_radial_scale_factor( dtype, ): # NOTE: Unlike RadialBasisFunctionWithCenters, the scale factors are unitless. - radial_scale_factor = torch.full( + return torch.full( (number_of_radial_basis_functions,), - number_of_radial_basis_functions - / ( - 2 - * ( - 1 - - math.exp( - 10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer) - ) - ) - ), + (2 * (1 - math.exp(10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer)))) / + number_of_radial_basis_functions, dtype=dtype, - ) # NOTE: radial_square_factor here is the square root of beta in the PhysNet paper - return radial_scale_factor + ) # NOTE: radial_square_factor here is the reciprocal of the square root of beta in the PhysNet paper def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: """ @@ -805,9 +799,7 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: 1/Angstrom, which is equivalent to 10/nanometer, to make the input to exp unitless. 
""" - return self.radial_scale_factor * ( - torch.exp(-distances * 10) - self.radial_basis_centers - ) + return (torch.exp(-distances * 10).unsqueeze(-1) - self.radial_basis_centers) / self.radial_scale_factor def pair_list( diff --git a/modelforge/tests/test_ani.py b/modelforge/tests/test_ani.py index 9a48c703..842ce0a8 100644 --- a/modelforge/tests/test_ani.py +++ b/modelforge/tests/test_ani.py @@ -126,7 +126,7 @@ def test_compare_radial_symmetry_features(): # Compare the ANI radial symmetry function # to the output of the modelforge radial symmetry function import torch - from modelforge.potential.utils import AniRadialSymmetryFunction, CosineCutoff + from modelforge.potential.utils import AniRadialBasisFunction, CosineCutoff from openff.units import unit # generate a random list of distances, all < 5 @@ -140,7 +140,7 @@ def test_compare_radial_symmetry_features(): ShfR = torch.linspace(radial_start, radial_cutoff, radial_dist_divisions + 1)[:-1] # NOTE: we pass in Angstrom to ANI and in nanometer to mf - rsf = AniRadialSymmetryFunction( + rsf = AniRadialBasisFunction( number_of_radial_basis_functions=radial_dist_divisions, max_distance=radial_cutoff * unit.angstrom, min_distance=radial_start * unit.angstrom, @@ -158,7 +158,7 @@ def test_compare_radial_symmetry_features(): def test_radial_with_diagonal_batching(setup_two_methanes): import torch - from modelforge.potential.utils import AniRadialSymmetryFunction, CosineCutoff + from modelforge.potential.utils import AniRadialBasisFunction, CosineCutoff from openff.units import unit from modelforge.potential.models import Pairlist from torchani.aev import neighbor_pairs_nopbc @@ -193,7 +193,7 @@ def test_radial_with_diagonal_batching(setup_two_methanes): # ------------ Modelforge calculation ----------# device = torch.device("cpu") - radial_symmetry_function = AniRadialSymmetryFunction( + radial_symmetry_function = AniRadialBasisFunction( radial_dist_divisions, radial_cutoff * unit.angstrom, radial_start * unit.angstrom, diff --git a/modelforge/tests/test_physnet.py b/modelforge/tests/test_physnet.py index 0ce0ad0f..a9002382 100644 --- a/modelforge/tests/test_physnet.py +++ b/modelforge/tests/test_physnet.py @@ -69,9 +69,9 @@ def test_rbf(): # RBF comparision ############################# # Initialize the rbf class - from modelforge.potential.utils import PhysNetRadialSymmetryFunction + from modelforge.potential.utils import PhysNetRadialBasisFunction - mf_rbf = PhysNetRadialSymmetryFunction( + mf_rbf = PhysNetRadialBasisFunction( number_of_radial_basis_functions, max_distance=_max_distance_in_nanometer * unit.nanometer, ) @@ -92,7 +92,8 @@ def softplus_inverse(x): # Modelforge implementation mf_widths_np = mf_rbf.get_buffer("radial_scale_factor").numpy() - assert np.allclose(pn_widths_np, mf_widths_np) + assert np.allclose(pn_widths_np ** -0.5, mf_widths_np) # we redefine the scale factor such that we can apply the + # Gaussian RBF # center_position ################# From 7e375e8229961ab34cc82a2166a3554a2ec74514 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 28 Jun 2024 12:07:49 -0700 Subject: [PATCH 14/78] Trying to fix SAKE but failing --- modelforge/potential/sake.py | 5 ++- modelforge/tests/test_sake.py | 81 ++++++++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index db3e0ebd..52abdb4e 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -321,9 +321,12 @@ def update_edge(self, h_i_by_pair, 
h_j_by_pair, d_ij): Intermediate edge features. Shape [nr_pairs, nr_edge_basis]. """ h_ij_cat = torch.cat([h_i_by_pair, h_j_by_pair], dim=-1) - h_ij_filtered = self.radial_symmetry_function_module(d_ij.unsqueeze(-1)) * self.edge_mlp_in( + print(f"{self.radial_symmetry_function_module(d_ij.unsqueeze(-1)).shape=}") + print(f"{self.edge_mlp_in(h_ij_cat).shape=}") + h_ij_filtered = self.radial_symmetry_function_module(d_ij.unsqueeze(-1)).squeeze(-2) * self.edge_mlp_in( h_ij_cat ) + print(f"{h_ij_filtered.shape=}, {h_ij_cat.shape=}, {d_ij.shape=}") return self.edge_mlp_out( torch.cat([h_ij_cat, h_ij_filtered, d_ij.unsqueeze(-1)], dim=-1) ) diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 008dbeba..00cd37b2 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -234,52 +234,103 @@ def make_equivalent_pairlist_mask(key, nr_atoms, nr_pairs, include_self_pairs): def test_radial_symmetry_function_against_reference(): from modelforge.potential.utils import ( PhysNetRadialBasisFunction, - SAKERadialBasisFunctionCore, ) from sake.utils import ExpNormalSmearing as RefExpNormalSmearing - nr_atoms = 13 - number_of_radial_basis_functions = 11 - cutoff_upper = 6.0 * unit.bohr - cutoff_lower = 2.0 * unit.bohr - mf_unit = unit.nanometer - ref_unit = unit.nanometer + nr_atoms = 1 + number_of_radial_basis_functions = 10 + cutoff_upper = 6.0 * unit.nanometer + cutoff_lower = 2.0 * unit.nanometer radial_symmetry_function_module = PhysNetRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, max_distance=cutoff_upper, min_distance=cutoff_lower, dtype=torch.float32, - trainable=False, - radial_basis_function=SAKERadialBasisFunctionCore(cutoff_lower), ) ref_radial_basis_module = RefExpNormalSmearing( num_rbf=number_of_radial_basis_functions, - cutoff_upper=cutoff_upper.to(ref_unit).m, - cutoff_lower=cutoff_lower.to(ref_unit).m, + cutoff_upper=cutoff_upper.m, + cutoff_lower=cutoff_lower.m, ) key = jax.random.PRNGKey(1884) # Generate random input data in JAX - d_ij_bohr_mag = jax.random.normal(key, (nr_atoms, nr_atoms, 1)) - d_ij_jax = (d_ij_bohr_mag * unit.bohr).to(ref_unit).m + d_ij_jax = jax.random.uniform(key, (nr_atoms, nr_atoms, 1)) d_ij = torch.from_numpy( - onp.array((d_ij_bohr_mag * unit.bohr).to(mf_unit).m) + onp.array(d_ij_jax) ).reshape(nr_atoms**2) mf_rbf = radial_symmetry_function_module(d_ij) variables = ref_radial_basis_module.init(key, d_ij_jax) + print(f"{variables['params']['means']=}") + print(f"{variables['params']['betas']=}") assert torch.allclose( torch.from_numpy(onp.array(variables["params"]["means"])), radial_symmetry_function_module.radial_basis_centers.detach().T, ) assert torch.allclose( - torch.from_numpy(onp.array(variables["params"]["betas"])), + torch.from_numpy(onp.array(variables["params"]["betas"])) ** -0.5, radial_symmetry_function_module.radial_scale_factor.detach().T, ) ref_rbf = ref_radial_basis_module.apply(variables, d_ij_jax) + print(f"{mf_rbf=}") + print(f"{ref_rbf=}") + + assert torch.allclose( + mf_rbf, + torch.from_numpy(onp.array(ref_rbf)).reshape( + nr_atoms**2, number_of_radial_basis_functions + ), + ) + +def test_rbf_forward(): + from modelforge.potential.utils import ( + PhysNetRadialBasisFunction, + ) + from sake.utils import ExpNormalSmearing as RefExpNormalSmearing + + nr_atoms = 1 + number_of_radial_basis_functions = 1 + cutoff_upper = 6.0 * unit.nanometer + cutoff_lower = 2.0 * unit.nanometer + + radial_symmetry_function_module = PhysNetRadialBasisFunction( + 
number_of_radial_basis_functions=number_of_radial_basis_functions, + max_distance=cutoff_upper, + min_distance=cutoff_lower, + dtype=torch.float32, + ) + ref_radial_basis_module = RefExpNormalSmearing( + num_rbf=number_of_radial_basis_functions, + cutoff_upper=cutoff_upper.m, + cutoff_lower=cutoff_lower.m, + ) + key = jax.random.PRNGKey(1882) + + # Generate random input data in JAX + d_ij_jax = jnp.full((nr_atoms, nr_atoms, 1), 1) + d_ij = torch.from_numpy( + onp.array(d_ij_jax) + ).reshape(nr_atoms**2) + + variables = ref_radial_basis_module.init(key, d_ij_jax) + + means = 1e-18 + betas = 25 + + variables["params"]["means"] = jnp.full_like(variables["params"]["means"], means) + radial_symmetry_function_module.radial_basis_centers[:] = means + variables["params"]["betas"] = jnp.full_like(variables["params"]["betas"], betas) + radial_symmetry_function_module.radial_scale_factor[:] = betas ** -2 + + mf_rbf = radial_symmetry_function_module(d_ij) + ref_rbf = ref_radial_basis_module.apply(variables, d_ij_jax) + print(f"{d_ij=}") + print(f"{mf_rbf=}") + print(f"{ref_rbf=}") assert torch.allclose( mf_rbf, From e48a3955fc97ab74616561ad6240d0947aeb8f85 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Mon, 1 Jul 2024 22:29:23 -0700 Subject: [PATCH 15/78] Fix SAKE test --- modelforge/potential/utils.py | 50 ++++++++++++++++++----------------- modelforge/tests/test_sake.py | 21 ++++++++------- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 6488d6a7..3954cd48 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -499,10 +499,9 @@ def forward(self, distances: torch.Tensor) -> torch.Tensor: ) -class RadialBasisFunctionWithCenters(RadialBasisFunction): +class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): def __init__( self, - radial_basis_function: Type[RadialBasisFunctionCore], number_of_radial_basis_functions: int, max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, @@ -510,7 +509,7 @@ def __init__( trainable_prefactor: bool = False, trainable_centers_and_scale_factors: bool = False, ): - """RadialSymmetryFunction class. + """ Initializes and contains the logic for computing radial symmetry functions. @@ -528,19 +527,17 @@ def __init__( Whether prefactor is trainable trainable_centers_and_scale_factors: bool, default False Whether centers and scale factors are trainable. - radial_basis_function: RadialBasisFunction, default GaussianRadialBasisFunction() Subclasses must implement the forward() method to compute the actual symmetry function output given an input distance matrix. """ - super().__init__(radial_basis_function, dtype, trainable_prefactor) + super().__init__(GaussianRadialBasisFunctionCore, dtype, trainable_prefactor) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.max_distance = max_distance self.min_distance = min_distance self.dtype = dtype self.trainable_centers_and_scale_factors = trainable_centers_and_scale_factors - self.radial_basis_function = radial_basis_function self.initialize_parameters() # The length of radial subaev of a single species self.radial_sublength = self.radial_basis_centers.numel() @@ -600,11 +597,12 @@ def calculate_radial_scale_factor( pass def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: + # Here, self.radial_scale_factor is interpreted as sqrt(2) times the standard deviation of the Gaussian. 
diff = distances - self.radial_basis_centers return diff / self.radial_scale_factor -class SchnetRadialBasisFunction(RadialBasisFunctionWithCenters): +class SchnetRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): def __init__( self, number_of_radial_basis_functions: int, @@ -667,7 +665,7 @@ def calculate_radial_scale_factor( return scale_factors -class AniRadialBasisFunction(RadialBasisFunctionWithCenters): +class AniRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): def __init__( self, number_of_radial_basis_functions, @@ -733,19 +731,19 @@ def __init__( trainable_centers_and_scale_factors: bool = False, ): super().__init__(GaussianRadialBasisFunctionCore, trainable_prefactor=False, dtype=dtype) - _max_distance_in_nanometer = max_distance.to(unit.nanometer).m - _min_distance_in_nanometer = min_distance.to(unit.nanometer).m + self._max_distance_in_nanometer = max_distance.to(unit.nanometer).m + self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m radial_basis_centers = self.calculate_radial_basis_centers( number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, + self._max_distance_in_nanometer, + self._min_distance_in_nanometer, dtype, ) # calculate scale factors radial_scale_factor = self.calculate_radial_scale_factor( number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, + self._max_distance_in_nanometer, + self._min_distance_in_nanometer, dtype ) @@ -763,8 +761,8 @@ def calculate_radial_basis_centers( _min_distance_in_nanometer, dtype, ): - # initialize means and betas according to the default values in PhysNet - # https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 + # initialize centers according to the default values in PhysNet https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 # noqa + # (see mu_k in Figure 2 caption) # NOTE: Unlike RadialBasisFunctionWithCenters, the centers are unitless. start_value = torch.exp( @@ -772,7 +770,8 @@ def calculate_radial_basis_centers( (-_max_distance_in_nanometer + _min_distance_in_nanometer) * 10, dtype=dtype, ) - ) # NOTE: there is an implicit multiplication by 1/Angstrom within the exp, so we multiply by 10/nanometer. + ) # NOTE: the PhysNet paper implicitly multiplies by 1/Angstrom within the exp, so we multiply + # _max_distance_in_nanometers and _min_distance_in_nanometers by 10/nanometer centers = torch.linspace( start_value, 1, number_of_radial_basis_functions, dtype=dtype ) @@ -785,21 +784,24 @@ def calculate_radial_scale_factor( _min_distance_in_nanometer, dtype, ): - # NOTE: Unlike RadialBasisFunctionWithCenters, the scale factors are unitless. + # initialize according to the default values in PhysNet (see beta_k in Eq. 7) + # NOTES: - Unlike RadialBasisFunctionWithCenters, the scale factors are unitless. + # - Each element of radial_square_factor here is the reciprocal of the square root of beta_k in the + # Eq. 7 of the PhysNet paper. This way, it is consistent with the sqrt(2) * standard deviation interpretation + # of radial_scale_factor in GaussianRadialBasisFunctionWithScaling return torch.full( (number_of_radial_basis_functions,), (2 * (1 - math.exp(10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer)))) / number_of_radial_basis_functions, dtype=dtype, - ) # NOTE: radial_square_factor here is the reciprocal of the square root of beta in the PhysNet paper + ) def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: - """ - NOTE: In PhysNet, the input to exp is in Angstroms. 
In modelforge, distances are in nanometer. Thus, we multiply - 1/Angstrom, which is equivalent to 10/nanometer, to make the input to exp unitless. - """ + # Transformation within the outer exp of PhysNet Eq. 7 + # NOTE: the PhysNet paper implicitly multiplies by 1/Angstrom within the inner exp but distances are in + # nanometers, so we multiply by 10/nanometer - return (torch.exp(-distances * 10).unsqueeze(-1) - self.radial_basis_centers) / self.radial_scale_factor + return (torch.exp((-distances + self._min_distance_in_nanometer) * 10).unsqueeze(-1) - self.radial_basis_centers) / self.radial_scale_factor def pair_list( diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 00cd37b2..91cf6c5a 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -10,7 +10,6 @@ import sake as reference_sake from sys import platform - ON_MAC = platform == "darwin" @@ -59,7 +58,7 @@ def test_sake_forward(single_batch_with_batchsize_64): nr_of_mols = methane.atomic_subsystem_indices.unique().shape[0] assert ( - len(energy) == nr_of_mols + len(energy) == nr_of_mols ) # Assuming energy is calculated per sample in the batch @@ -259,7 +258,7 @@ def test_radial_symmetry_function_against_reference(): d_ij_jax = jax.random.uniform(key, (nr_atoms, nr_atoms, 1)) d_ij = torch.from_numpy( onp.array(d_ij_jax) - ).reshape(nr_atoms**2) + ).reshape(nr_atoms ** 2) mf_rbf = radial_symmetry_function_module(d_ij) variables = ref_radial_basis_module.init(key, d_ij_jax) @@ -282,16 +281,19 @@ def test_radial_symmetry_function_against_reference(): assert torch.allclose( mf_rbf, torch.from_numpy(onp.array(ref_rbf)).reshape( - nr_atoms**2, number_of_radial_basis_functions + nr_atoms ** 2, number_of_radial_basis_functions ), ) + def test_rbf_forward(): from modelforge.potential.utils import ( PhysNetRadialBasisFunction, ) from sake.utils import ExpNormalSmearing as RefExpNormalSmearing + jax.numpy.set_printoptions(precision=100) + torch.set_printoptions(precision=100) nr_atoms = 1 number_of_radial_basis_functions = 1 cutoff_upper = 6.0 * unit.nanometer @@ -314,12 +316,12 @@ def test_rbf_forward(): d_ij_jax = jnp.full((nr_atoms, nr_atoms, 1), 1) d_ij = torch.from_numpy( onp.array(d_ij_jax) - ).reshape(nr_atoms**2) + ).reshape(nr_atoms ** 2) variables = ref_radial_basis_module.init(key, d_ij_jax) - means = 1e-18 - betas = 25 + means = 0 + betas = 1 variables["params"]["means"] = jnp.full_like(variables["params"]["means"], means) radial_symmetry_function_module.radial_basis_centers[:] = means @@ -335,7 +337,7 @@ def test_rbf_forward(): assert torch.allclose( mf_rbf, torch.from_numpy(onp.array(ref_rbf)).reshape( - nr_atoms**2, number_of_radial_basis_functions + nr_atoms ** 2, number_of_radial_basis_functions ), ) @@ -344,7 +346,6 @@ def test_rbf_forward(): @pytest.mark.parametrize("include_self_pairs", [True, False]) @pytest.mark.parametrize("v_is_none", [True, False]) def test_sake_layer_against_reference(include_self_pairs, v_is_none): - nr_atoms = 13 out_features = 11 hidden_features = 7 @@ -547,7 +548,7 @@ def test_sake_model_against_reference(single_batch_with_batchsize_1): if layer_name.startswith("d") ) for (layer_name, layer), mf_sake_block in zip( - layers, mf_sake.core_module.interaction_modules.children() + layers, mf_sake.core_module.interaction_modules.children() ): layer["edge_model"]["kernel"]["betas"] = ( mf_sake_block.radial_symmetry_function_module.radial_scale_factor.detach() From fdfa4d5daadad3997e1b7199d2fca0a08d0c8d2b Mon Sep 17 00:00:00 2001 From: Arnav Nagle 
Date: Mon, 1 Jul 2024 22:30:00 -0700 Subject: [PATCH 16/78] Remove unnecessary test --- modelforge/tests/test_sake.py | 56 ----------------------------------- 1 file changed, 56 deletions(-) diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 91cf6c5a..199f986c 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -286,62 +286,6 @@ def test_radial_symmetry_function_against_reference(): ) -def test_rbf_forward(): - from modelforge.potential.utils import ( - PhysNetRadialBasisFunction, - ) - from sake.utils import ExpNormalSmearing as RefExpNormalSmearing - - jax.numpy.set_printoptions(precision=100) - torch.set_printoptions(precision=100) - nr_atoms = 1 - number_of_radial_basis_functions = 1 - cutoff_upper = 6.0 * unit.nanometer - cutoff_lower = 2.0 * unit.nanometer - - radial_symmetry_function_module = PhysNetRadialBasisFunction( - number_of_radial_basis_functions=number_of_radial_basis_functions, - max_distance=cutoff_upper, - min_distance=cutoff_lower, - dtype=torch.float32, - ) - ref_radial_basis_module = RefExpNormalSmearing( - num_rbf=number_of_radial_basis_functions, - cutoff_upper=cutoff_upper.m, - cutoff_lower=cutoff_lower.m, - ) - key = jax.random.PRNGKey(1882) - - # Generate random input data in JAX - d_ij_jax = jnp.full((nr_atoms, nr_atoms, 1), 1) - d_ij = torch.from_numpy( - onp.array(d_ij_jax) - ).reshape(nr_atoms ** 2) - - variables = ref_radial_basis_module.init(key, d_ij_jax) - - means = 0 - betas = 1 - - variables["params"]["means"] = jnp.full_like(variables["params"]["means"], means) - radial_symmetry_function_module.radial_basis_centers[:] = means - variables["params"]["betas"] = jnp.full_like(variables["params"]["betas"], betas) - radial_symmetry_function_module.radial_scale_factor[:] = betas ** -2 - - mf_rbf = radial_symmetry_function_module(d_ij) - ref_rbf = ref_radial_basis_module.apply(variables, d_ij_jax) - print(f"{d_ij=}") - print(f"{mf_rbf=}") - print(f"{ref_rbf=}") - - assert torch.allclose( - mf_rbf, - torch.from_numpy(onp.array(ref_rbf)).reshape( - nr_atoms ** 2, number_of_radial_basis_functions - ), - ) - - @pytest.mark.skipif(ON_MAC, reason="Test fails on macOS") @pytest.mark.parametrize("include_self_pairs", [True, False]) @pytest.mark.parametrize("v_is_none", [True, False]) From c4b24b4e0ff1af5cc4f12e2ed2261a5b61f2de16 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Mon, 1 Jul 2024 22:40:40 -0700 Subject: [PATCH 17/78] Refactor prefactor --- modelforge/potential/utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 3954cd48..7ca5b5b0 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -455,13 +455,14 @@ def __init__( self, radial_basis_function: Type[RadialBasisFunctionCore], dtype, + prefactor: float = 1.0, trainable_prefactor: bool = False, ): super().__init__() if trainable_prefactor: - self.prefactor = nn.Parameter(torch.tensor([1.0], dtype=dtype)) + self.prefactor = nn.Parameter(torch.tensor([prefactor], dtype=dtype)) else: - self.register_buffer("prefactor", torch.tensor([1.0], dtype=dtype)) + self.register_buffer("prefactor", torch.tensor([prefactor], dtype=dtype)) self.radial_basis_function = radial_basis_function @abstractmethod @@ -506,6 +507,7 @@ def __init__( max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, dtype: Optional[torch.dtype] = None, + prefactor: float = 1.0, trainable_prefactor: bool = False, 
trainable_centers_and_scale_factors: bool = False, ): @@ -523,6 +525,8 @@ def __init__( Minimum distance to consider. dtype: Data type for computations. + prefactor: + Scalar factor by which to multiply output of radial basis functions. trainable_prefactor: bool, default False Whether prefactor is trainable trainable_centers_and_scale_factors: bool, default False @@ -532,7 +536,7 @@ def __init__( symmetry function output given an input distance matrix. """ - super().__init__(GaussianRadialBasisFunctionCore, dtype, trainable_prefactor) + super().__init__(GaussianRadialBasisFunctionCore, dtype, prefactor, trainable_prefactor) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.max_distance = max_distance self.min_distance = min_distance @@ -620,7 +624,6 @@ def __init__( """ super().__init__( - GaussianRadialBasisFunctionCore, number_of_radial_basis_functions, max_distance, min_distance, @@ -683,15 +686,14 @@ def __init__( """ super().__init__( - GaussianRadialBasisFunctionCore, number_of_radial_basis_functions, max_distance, min_distance, dtype, + prefactor=0.25, trainable_prefactor=False, trainable_centers_and_scale_factors=trainable_centers_and_scale_factors, ) - self.prefactor = torch.tensor([0.25], dtype=dtype) @staticmethod def calculate_radial_basis_centers( From 25a5ac9bfa8c207fefbacf2a113dd13255497357 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Mon, 1 Jul 2024 22:48:39 -0700 Subject: [PATCH 18/78] Update comment for nondimensionalization --- modelforge/potential/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 7ca5b5b0..a4f131e7 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -482,7 +482,10 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: def forward(self, distances: torch.Tensor) -> torch.Tensor: """ - Applies nondimensionalization transformations on the distances and passes the result to RadialBasisFunctionCore. + The input distances have implicit units of nanometers by the convention of modelforge. This function applies + nondimensionalization transformations on the distances and passes the dimensionless result to + RadialBasisFunctionCore. There can be several nondimensionalization transformations, one corresponding to each + element along the number_of_radial_basis_functions axis of the output.
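+        For example (an illustrative sketch, using names defined elsewhere in this file): in
+        GaussianRadialBasisFunctionWithScaling the full forward pass amounts to
+        prefactor * exp(-(((distances - radial_basis_centers) / radial_scale_factor) ** 2)).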
Parameters --------- From c0280fb622b19d5c064d949ce1281753b180e268 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Mon, 1 Jul 2024 23:55:22 -0700 Subject: [PATCH 19/78] Fix SAKE and ANI bugs --- modelforge/potential/ani.py | 9 +++----- modelforge/potential/utils.py | 33 +++++++++++--------------- modelforge/tests/test_sake.py | 6 +---- modelforge/tests/test_utils.py | 42 +++++++++++++++++++++++++++++++--- 4 files changed, 56 insertions(+), 34 deletions(-) diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index 432bd09f..9490d074 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -251,7 +251,7 @@ def _postprocess_radial_aev( radial_feature_vector = radial_feature_vector.squeeze(1) number_of_atoms = data.number_of_atoms - radial_sublength = self.radial_symmetry_functions.radial_sublength + radial_sublength = self.radial_symmetry_functions.number_of_radial_basis_functions radial_length = radial_sublength * self.nr_of_supported_elements radial_aev = radial_feature_vector.new_zeros( @@ -445,7 +445,7 @@ def __init__( angle_sections: int = 4, ) -> None: """ - Initialize the ANi NNP architeture. + Initialize the ANi NNP architecture. Parameters ---------- @@ -467,10 +467,7 @@ def __init__( angle_sections, ) # The length of radial aev - self.radial_length = ( - self.num_species - * self.ani_representation_module.radial_symmetry_functions.radial_sublength - ) + self.radial_length = self.num_species * number_of_radial_basis_functions # The length of angular aev self.angular_length = ( (self.num_species * (self.num_species + 1)) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index a4f131e7..c3376b84 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -504,6 +504,9 @@ def forward(self, distances: torch.Tensor) -> torch.Tensor: class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): + """ + Shifts inputs by a set of centers and scales by a set of scale factors before passing into the standard Gaussian. + """ def __init__( self, number_of_radial_basis_functions: int, @@ -515,9 +518,6 @@ def __init__( trainable_centers_and_scale_factors: bool = False, ): """ - - Initializes and contains the logic for computing radial symmetry functions. - Parameters --------- number_of_radial_basis_functions: int @@ -534,25 +534,15 @@ def __init__( Whether prefactor is trainable trainable_centers_and_scale_factors: bool, default False Whether centers and scale factors are trainable. - - Subclasses must implement the forward() method to compute the actual - symmetry function output given an input distance matrix. 
""" super().__init__(GaussianRadialBasisFunctionCore, dtype, prefactor, trainable_prefactor) self.number_of_radial_basis_functions = number_of_radial_basis_functions - self.max_distance = max_distance - self.min_distance = min_distance self.dtype = dtype self.trainable_centers_and_scale_factors = trainable_centers_and_scale_factors - self.initialize_parameters() - # The length of radial subaev of a single species - self.radial_sublength = self.radial_basis_centers.numel() - - def initialize_parameters(self): # convert to nanometer - _max_distance_in_nanometer = self.max_distance.to(unit.nanometer).m - _min_distance_in_nanometer = self.min_distance.to(unit.nanometer).m + _max_distance_in_nanometer = max_distance.to(unit.nanometer).m + _min_distance_in_nanometer = min_distance.to(unit.nanometer).m # calculate radial basis centers radial_basis_centers = self.calculate_radial_basis_centers( @@ -766,8 +756,8 @@ def calculate_radial_basis_centers( _min_distance_in_nanometer, dtype, ): - # initialize centers according to the default values in PhysNet https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181 # noqa - # (see mu_k in Figure 2 caption) + # initialize centers according to the default values in PhysNet + # (see mu_k in Figure 2 caption of https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181) # NOTE: Unlike RadialBasisFunctionWithCenters, the centers are unitless. start_value = torch.exp( @@ -789,8 +779,9 @@ def calculate_radial_scale_factor( _min_distance_in_nanometer, dtype, ): - # initialize according to the default values in PhysNet (see beta_k in Eq. 7) - # NOTES: - Unlike RadialBasisFunctionWithCenters, the scale factors are unitless. + # initialize according to the default values in PhysNet (see beta_k in Figure 2 caption) + # NOTES: + # - Unlike RadialBasisFunctionWithCenters, the scale factors are unitless. # - Each element of radial_square_factor here is the reciprocal of the square root of beta_k in the # Eq. 7 of the PhysNet paper. 
This way, it is consistent with the sqrt(2) * standard deviation interpretation # of radial_scale_factor in GaussianRadialBasisFunctionWithScaling @@ -806,7 +797,9 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: # NOTE: the PhysNet paper implicitly multiplies by 1/Angstrom within the inner exp but distances are in # nanometers, so we multiply by 10/nanometer - return (torch.exp((-distances + self._min_distance_in_nanometer) * 10).unsqueeze(-1) - self.radial_basis_centers) / self.radial_scale_factor + return ((torch.exp( + (-distances + self._min_distance_in_nanometer) * 10).unsqueeze(-1) - self.radial_basis_centers) + / self.radial_scale_factor) def pair_list( diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 199f986c..6db56df5 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -263,8 +263,6 @@ def test_radial_symmetry_function_against_reference(): mf_rbf = radial_symmetry_function_module(d_ij) variables = ref_radial_basis_module.init(key, d_ij_jax) - print(f"{variables['params']['means']=}") - print(f"{variables['params']['betas']=}") assert torch.allclose( torch.from_numpy(onp.array(variables["params"]["means"])), radial_symmetry_function_module.radial_basis_centers.detach().T, @@ -275,8 +273,6 @@ def test_radial_symmetry_function_against_reference(): ) ref_rbf = ref_radial_basis_module.apply(variables, d_ij_jax) - print(f"{mf_rbf=}") - print(f"{ref_rbf=}") assert torch.allclose( mf_rbf, @@ -326,7 +322,7 @@ def test_sake_layer_against_reference(include_self_pairs, v_is_none): layer = variables["params"] assert torch.allclose( - torch.from_numpy(onp.array(layer["edge_model"]["kernel"]["betas"])), + torch.from_numpy(onp.array(layer["edge_model"]["kernel"]["betas"]) ** -0.5), mf_sake_block.radial_symmetry_function_module.radial_scale_factor.detach().T, ) assert torch.allclose( diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py index 9b5b4eb9..b3450e62 100644 --- a/modelforge/tests/test_utils.py +++ b/modelforge/tests/test_utils.py @@ -2,7 +2,6 @@ import torch import pytest -from modelforge.potential.utils import CosineCutoff, RadialBasisFunction def test_dense_layer(): @@ -73,6 +72,7 @@ def test_cosine_cutoff(): """ Test the cosine cutoff implementation. """ + from modelforge.potential.utils import CosineCutoff # Define inputs x = torch.Tensor([1, 2, 3]) y = torch.Tensor([4, 5, 6]) @@ -96,6 +96,7 @@ def test_cosine_cutoff(): def test_cosine_cutoff_module(): # Test CosineCutoff module + from modelforge.potential.utils import CosineCutoff from openff.units import unit # test the cutoff on this distance vector (NOTE: it is in angstrom) @@ -116,16 +117,51 @@ def test_radial_symmetry_function_implementation(): """ Test the Radial Symmetry function implementation. 
""" - from modelforge.potential.utils import RadialBasisFunction, CosineCutoff import torch from openff.units import unit import numpy as np + from modelforge.potential.utils import CosineCutoff, GaussianRadialBasisFunctionWithScaling cutoff_module = CosineCutoff(cutoff=unit.Quantity(5.0, unit.angstrom)) - RSF = RadialBasisFunction( + + class RadialSymmetryFunctionTest(GaussianRadialBasisFunctionWithScaling): + @staticmethod + def calculate_radial_basis_centers( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, + ): + centers = torch.linspace( + _min_distance_in_nanometer, + _max_distance_in_nanometer, + number_of_radial_basis_functions, + dtype=dtype, + ) + return centers + + @staticmethod + def calculate_radial_scale_factor( + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype + ): + scale_factors = torch.full( + (number_of_radial_basis_functions,), + (_min_distance_in_nanometer - _max_distance_in_nanometer) + / number_of_radial_basis_functions, + ) + scale_factors = scale_factors * -15_000 + return scale_factors + + + RSF = RadialSymmetryFunctionTest( number_of_radial_basis_functions=18, max_distance=unit.Quantity(5.0, unit.angstrom), ) + print(f"{RSF.radial_basis_centers=}") + print(f"{RSF.radial_scale_factor=}") # test a single distance d_ij = torch.tensor([[0.2]]) radial_expension = RSF(d_ij) From a3bfc0e277295cb23dc36944e72ca92b5e57595a Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 2 Jul 2024 00:01:26 -0700 Subject: [PATCH 20/78] Fix test radial symmetry function --- modelforge/tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py index b3450e62..83d266c3 100644 --- a/modelforge/tests/test_utils.py +++ b/modelforge/tests/test_utils.py @@ -152,7 +152,7 @@ def calculate_radial_scale_factor( (_min_distance_in_nanometer - _max_distance_in_nanometer) / number_of_radial_basis_functions, ) - scale_factors = scale_factors * -15_000 + scale_factors = (scale_factors * -15_000) ** -0.5 return scale_factors From 6b9f171a6425e98b4e30824b48766b19356c77b4 Mon Sep 17 00:00:00 2001 From: Marcus Wieder <31651017+wiederm@users.noreply.github.com> Date: Wed, 3 Jul 2024 04:26:17 -0400 Subject: [PATCH 21/78] Update utils.py Small changes to the dtype that is passed to the RBF. I think it is best to set a default value for the publically exposed class, e.g., `SchNetRadialBasisFunction`) and have optional `dtype` parameters present in the internal classes (which will then fail if it wasn't set in the entry point classes). Also added parameter docstring to some of the RBF classes. --- modelforge/potential/utils.py | 66 +++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 15 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index c3376b84..9a477104 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -454,7 +454,7 @@ class RadialBasisFunction(nn.Module, ABC): def __init__( self, radial_basis_function: Type[RadialBasisFunctionCore], - dtype, + type:torch.dtype, prefactor: float = 1.0, trainable_prefactor: bool = False, ): @@ -526,9 +526,9 @@ def __init__( Maximum distance to consider for symmetry functions. min_distance: unit.Quantity Minimum distance to consider. - dtype: + dtype: torch.dtype, default None Data type for computations. 
- prefactor: + prefactor: float Scalar factor by which to multiply output of radial basis functions. trainable_prefactor: bool, default False Whether prefactor is trainable @@ -600,6 +600,9 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: class SchnetRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): + """ + Implementation of the radial basis function as used by the SchNet neural network + """ def __init__( self, number_of_radial_basis_functions: int, @@ -608,14 +611,20 @@ def __init__( dtype: Optional[torch.dtype] = None, trainable_centers_and_scale_factors: bool = False, ): - """RadialSymmetryFunction class. - - Initializes and contains the logic for computing radial symmetry functions. - + """ Parameters --------- + number_of_radial_basis_functions: int + Number of radial basis functions to use. + max_distance: unit.Quantity + Maximum distance to consider for symmetry functions. + min_distance: unit.Quantity + Minimum distance to consider. + dtype: torch.dtype, default None + Data type for computations. + trainable_centers_and_scale_factors: bool, default False + Whether centers and scale factors are trainable. """ - super().__init__( number_of_radial_basis_functions, max_distance, @@ -662,22 +671,31 @@ def calculate_radial_scale_factor( class AniRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): + """ + Implementation of the radial basis function as used by the ANI neural network + """ def __init__( self, number_of_radial_basis_functions, max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype = torch.float32, trainable_centers_and_scale_factors: bool = False, ): - """RadialSymmetryFunction class. - - Initializes and contains the logic for computing radial symmetry functions. - + """ Parameters --------- + number_of_radial_basis_functions: int + Number of radial basis functions to use. + max_distance: unit.Quantity + Maximum distance to consider for symmetry functions. + min_distance: unit.Quantity + Minimum distance to consider. + dtype: torch.dtype, default torch.float32 + Data type for computations. + trainable_centers_and_scale_factors: bool, default False + Whether centers and scale factors are trainable. """ - super().__init__( number_of_radial_basis_functions, max_distance, @@ -716,15 +734,33 @@ def calculate_radial_scale_factor( class PhysNetRadialBasisFunction(RadialBasisFunction): + """ + Implementation of the radial basis function as used by the PhysNet neural network + """ def __init__( self, number_of_radial_basis_functions: int, max_distance: unit.Quantity, min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, + dtype: torch.dtype = torch.float32, trainable_centers_and_scale_factors: bool = False, ): + """ + Parameters + ---------- + number_of_radial_basis_functions : int + Number of radial basis functions to use. + max_distance : unit.Quantity + Maximum distance to consider for symmetry functions. + min_distance : unit.Quantity, optional + Minimum distance to consider, by default 0.0 * unit.nanometer. + dtype : torch.dtype, optional + Data type for computations, by default torch.float32. + trainable_centers_and_scale_factors : bool, optional + Whether centers and scale factors are trainable, by default False.
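+
+        Sketch of the transformation implemented below (cf. Eq. 7 of the PhysNet paper, an
+        illustrative restatement rather than code from this patch):
+        rbf_k(d) = exp(-(((exp(-(d - d_min) * 10) - mu_k) / s_k) ** 2)), where the unitless
+        centers mu_k are radial_basis_centers, the scale factors s_k are radial_scale_factor,
+        and the factor of 10 converts nanometer distances to the paper's 1/angstrom convention.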
+ """ + super().__init__(GaussianRadialBasisFunctionCore, trainable_prefactor=False, dtype=dtype) self._max_distance_in_nanometer = max_distance.to(unit.nanometer).m self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m From a4f41b91416ada3aab871e4b0ed20a47b840a132 Mon Sep 17 00:00:00 2001 From: Marcus Wieder <31651017+wiederm@users.noreply.github.com> Date: Wed, 3 Jul 2024 13:50:54 +0200 Subject: [PATCH 22/78] typo fix --- modelforge/potential/utils.py | 238 ++++++++++++++++++---------------- 1 file changed, 128 insertions(+), 110 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 9a477104..3c45475d 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -43,7 +43,7 @@ class Metadata: F: torch.Tensor = torch.tensor([], dtype=torch.float32) def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): """Move all tensors in this instance to the specified device.""" if device: @@ -65,9 +65,9 @@ class BatchData: metadata: Metadata def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): self.nnp_input = self.nnp_input.to(device=device, dtype=dtype) self.metadata = self.metadata.to(device=device, dtype=dtype) @@ -86,7 +86,7 @@ def shared_config_prior(): def triple_by_molecule( - atom_pairs: torch.Tensor, + atom_pairs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input: indices for pairs of atoms that are close to each other. each pair only appear once, i.e. only one of the pairs (1, 2) and @@ -122,8 +122,8 @@ def cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor: torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) ) mask = ( - torch.arange(intra_pair_indices.shape[2], device=ai1.device) - < pair_sizes.unsqueeze(1) + torch.arange(intra_pair_indices.shape[2], device=ai1.device) + < pair_sizes.unsqueeze(1) ).flatten() sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) @@ -200,13 +200,13 @@ class Dense(nn.Linear): """ def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - activation: Optional[nn.Module] = None, - weight_init: Callable = xavier_uniform_, - bias_init: Callable = zeros_, + self, + in_features: int, + out_features: int, + bias: bool = True, + activation: Optional[nn.Module] = None, + weight_init: Callable = xavier_uniform_, + bias_init: Callable = zeros_, ): """ Args: @@ -268,7 +268,7 @@ def forward(self, d_ij: torch.Tensor): """ # Compute values of cutoff function input_cut = 0.5 * ( - torch.cos(d_ij * np.pi / self.cutoff) + 1.0 + torch.cos(d_ij * np.pi / self.cutoff) + 1.0 ) # NOTE: ANI adds 0.5 instead of 1. 
# Remove contributions beyond the cutoff radius input_cut *= (d_ij < self.cutoff).float() @@ -312,13 +312,13 @@ class AngularSymmetryFunction(nn.Module): """ def __init__( - self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, - number_of_gaussians_for_asf: int = 8, - angle_sections: int = 4, - trainable: bool = False, - dtype: Optional[torch.dtype] = None, + self, + max_distance: unit.Quantity, + min_distance: unit.Quantity, + number_of_gaussians_for_asf: int = 8, + angle_sections: int = 4, + trainable: bool = False, + dtype: Optional[torch.dtype] = None, ) -> None: """ Parameters @@ -446,17 +446,17 @@ class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): @staticmethod def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: - return torch.exp(-(nondimensionalized_distances ** 2)) + return torch.exp(-(nondimensionalized_distances**2)) class RadialBasisFunction(nn.Module, ABC): def __init__( - self, - radial_basis_function: Type[RadialBasisFunctionCore], - type:torch.dtype, - prefactor: float = 1.0, - trainable_prefactor: bool = False, + self, + radial_basis_function: Type[RadialBasisFunctionCore], + dtype: torch.dtype, + prefactor: float = 1.0, + trainable_prefactor: bool = False, ): super().__init__() if trainable_prefactor: @@ -507,15 +507,16 @@ class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): """ Shifts inputs by a set of centers and scales by a set of scale factors before passing into the standard Gaussian. """ + def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - prefactor: float = 1.0, - trainable_prefactor: bool = False, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + prefactor: float = 1.0, + trainable_prefactor: bool = False, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -536,7 +537,9 @@ def __init__( Whether centers and scale factors are trainable. 
""" - super().__init__(GaussianRadialBasisFunctionCore, dtype, prefactor, trainable_prefactor) + super().__init__( + GaussianRadialBasisFunctionCore, dtype, prefactor, trainable_prefactor + ) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.dtype = dtype self.trainable_centers_and_scale_factors = trainable_centers_and_scale_factors @@ -556,7 +559,7 @@ def __init__( self.number_of_radial_basis_functions, _max_distance_in_nanometer, _min_distance_in_nanometer, - self.dtype + self.dtype, ) # either add as parameters or register buffers @@ -570,10 +573,10 @@ def __init__( @staticmethod @abstractmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): """ NOTE: centers have units of nanometers @@ -583,10 +586,10 @@ def calculate_radial_basis_centers( @staticmethod @abstractmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): """ NOTE: radial scale factors have units of nanometers @@ -603,13 +606,14 @@ class SchnetRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): """ Implementation of the radial basis function as used by the SchNet neural network """ + def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -636,10 +640,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): return torch.linspace( _min_distance_in_nanometer, @@ -650,10 +654,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): scale_factors = torch.linspace( _min_distance_in_nanometer, @@ -662,8 +666,8 @@ def calculate_radial_scale_factor( ) widths = ( - torch.abs(scale_factors[1] - scale_factors[0]) - * torch.ones_like(scale_factors) + torch.abs(scale_factors[1] - scale_factors[0]) + * torch.ones_like(scale_factors) ).to(dtype) scale_factors = math.sqrt(2) * widths @@ -674,13 +678,14 @@ class AniRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): """ Implementation of the radial basis function as used by the ANI neural network """ + def __init__( - self, - number_of_radial_basis_functions, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + 
dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -708,10 +713,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): centers = torch.linspace( _min_distance_in_nanometer, @@ -723,13 +728,15 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # ANI uses a predefined scaling factor - scale_factors = torch.full((number_of_radial_basis_functions,), (19.7 * 100) ** -0.5) + scale_factors = torch.full( + (number_of_radial_basis_functions,), (19.7 * 100) ** -0.5 + ) return scale_factors @@ -739,12 +746,12 @@ class PhysNetRadialBasisFunction(RadialBasisFunction): """ def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -761,7 +768,9 @@ def __init__( Whether centers and scale factors are trainable, by default False. """ - super().__init__(GaussianRadialBasisFunctionCore, trainable_prefactor=False, dtype=dtype) + super().__init__( + GaussianRadialBasisFunctionCore, trainable_prefactor=False, dtype=dtype + ) self._max_distance_in_nanometer = max_distance.to(unit.nanometer).m self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m radial_basis_centers = self.calculate_radial_basis_centers( @@ -775,7 +784,7 @@ def __init__( number_of_radial_basis_functions, self._max_distance_in_nanometer, self._min_distance_in_nanometer, - dtype + dtype, ) if trainable_centers_and_scale_factors: @@ -787,10 +796,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # initialize centers according to the default values in PhysNet # (see mu_k in Figure 2 caption of https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181) @@ -810,10 +819,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # initialize according to the default values in PhysNet (see beta_k in Figure 2 caption) # NOTES: @@ -823,8 +832,16 @@ def calculate_radial_scale_factor( # of radial_scale_factor in GaussianRadialBasisFunctionWithScaling return torch.full( (number_of_radial_basis_functions,), - (2 * (1 - math.exp(10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer)))) / - number_of_radial_basis_functions, + ( + 2 + * ( + 1 + - math.exp( + 10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer) + ) + ) + ) + / 
number_of_radial_basis_functions, dtype=dtype, ) @@ -833,14 +850,15 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: # NOTE: the PhysNet paper implicitly multiplies by 1/Angstrom within the inner exp but distances are in # nanometers, so we multiply by 10/nanometer - return ((torch.exp( - (-distances + self._min_distance_in_nanometer) * 10).unsqueeze(-1) - self.radial_basis_centers) - / self.radial_scale_factor) + return ( + torch.exp((-distances + self._min_distance_in_nanometer) * 10).unsqueeze(-1) + - self.radial_basis_centers + ) / self.radial_scale_factor def pair_list( - atomic_subsystem_indices: torch.Tensor, - only_unique_pairs: bool = False, + atomic_subsystem_indices: torch.Tensor, + only_unique_pairs: bool = False, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -881,7 +899,7 @@ def pair_list( # filter pairs to only keep those belonging to the same molecule same_molecule_mask = ( - atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] + atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] ) # Apply mask to get final pair indices @@ -894,9 +912,9 @@ def pair_list( return pair_indices.to(device) def forward( - self, - coordinates: torch.Tensor, # in nanometer - atomic_subsystem_indices: torch.Tensor, + self, + coordinates: torch.Tensor, # in nanometer + atomic_subsystem_indices: torch.Tensor, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -928,11 +946,11 @@ def forward( def scatter_softmax( - src: torch.Tensor, - index: torch.Tensor, - dim: int, - dim_size: Optional[int] = None, - device: Optional[torch.device] = None, + src: torch.Tensor, + index: torch.Tensor, + dim: int, + dim_size: Optional[int] = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Softmax operation over all values in :attr:`src` tensor that share indices @@ -966,7 +984,7 @@ def scatter_softmax( assert dim >= 0, f"dim must be non-negative, got {dim}" assert ( - dim < src.dim() + dim < src.dim() ), f"dim must be less than the number of dimensions of src {src.dim()}, got {dim}" out_shape = [ From 0e492d7fa0cb48c863c686d716ddd568126d5b76 Mon Sep 17 00:00:00 2001 From: Arnav Nagle <43835955+ArnNag@users.noreply.github.com> Date: Fri, 5 Jul 2024 09:32:20 -0700 Subject: [PATCH 23/78] Make RadialBasisFunctionCore inherit from nn.Module --- modelforge/potential/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 3c45475d..4951b3e1 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -425,10 +425,10 @@ def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: from abc import ABC, abstractmethod -class RadialBasisFunctionCore(ABC): - @staticmethod +class RadialBasisFunctionCore(nn.Module, ABC): + @abstractmethod - def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: """ Parameters --------- @@ -445,7 +445,7 @@ def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): @staticmethod - def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: return torch.exp(-(nondimensionalized_distances**2)) @@ -453,7 +453,7 @@ class RadialBasisFunction(nn.Module, ABC): 
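    # Composition note: this base class pairs a RadialBasisFunctionCore module with an
    # (optionally trainable) prefactor; concrete subclasses only supply
    # nondimensionalize_distances(), so the core always receives dimensionless input.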
def __init__( self, - radial_basis_function: Type[RadialBasisFunctionCore], + radial_basis_function: RadialBasisFunctionCore, dtype: torch.dtype, prefactor: float = 1.0, trainable_prefactor: bool = False, @@ -498,7 +498,7 @@ def forward(self, distances: torch.Tensor) -> torch.Tensor: Output of radial basis functions. """ nondimensionalized_distances = self.nondimensionalize_distances(distances) - return self.prefactor * self.radial_basis_function.compute( + return self.prefactor * self.radial_basis_function( nondimensionalized_distances ) @@ -538,7 +538,7 @@ def __init__( """ super().__init__( - GaussianRadialBasisFunctionCore, dtype, prefactor, trainable_prefactor + GaussianRadialBasisFunctionCore(), dtype, prefactor, trainable_prefactor ) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.dtype = dtype @@ -769,7 +769,7 @@ def __init__( """ super().__init__( - GaussianRadialBasisFunctionCore, trainable_prefactor=False, dtype=dtype + GaussianRadialBasisFunctionCore(), trainable_prefactor=False, dtype=dtype ) self._max_distance_in_nanometer = max_distance.to(unit.nanometer).m self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m From bb387c5345c4ac9a93ddae9ac7932366c7817b32 Mon Sep 17 00:00:00 2001 From: Arnav Nagle <43835955+ArnNag@users.noreply.github.com> Date: Fri, 5 Jul 2024 15:25:07 -0700 Subject: [PATCH 24/78] Remove @staticmethod decorator --- modelforge/potential/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 4951b3e1..c0daee52 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -444,7 +444,6 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): - @staticmethod def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: return torch.exp(-(nondimensionalized_distances**2)) From 5f35b728c9ad861350c47ead65a4c7b66d661511 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 5 Jul 2024 17:25:17 -0700 Subject: [PATCH 25/78] Try to implement Bernstein polynomials --- modelforge/potential/utils.py | 146 +++++++++++++++-------------- modelforge/tests/test_spookynet.py | 4 +- 2 files changed, 79 insertions(+), 71 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 5ea98f4e..6ac2e1d4 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -282,6 +282,7 @@ class SpookyNetCutoff(nn.Module): electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). Adapted from https://github.com/OUnke/SpookyNet/blob/d57b1fc02c4f1304a9445b2b9aa55a906818dd1b/spookynet/functional.py#L19 # noqa """ + def __init__(self, cutoff: unit.Quantity): """ @@ -308,80 +309,31 @@ def forward(self, d_ij: torch.Tensor): ) -class ExponentialBernsteinPolynomials(nn.Module): - """ - Taken from SpookyNet. - Radial basis functions based on exponential Bernstein polynomials given by: - b_{v,n}(x) = (n over v) * exp(-alpha*x)**v * (1-exp(-alpha*x))**(n-v) - (see https://en.wikipedia.org/wiki/Bernstein_polynomial) - Here, n = num_basis_functions-1 and v takes values from 0 to n. This - implementation operates in log space to prevent multiplication of very large - (n over v) and very small numbers (exp(-alpha*x)**v and - (1-exp(-alpha*x))**(n-v)) for numerical stability. - NOTE: There is a problem for x = 0, as log(-expm1(0)) will be log(0) = -inf. 
- This itself is not an issue, but the buffer v contains an entry 0 and - 0*(-inf)=nan. The correct behaviour could be recovered by replacing the nan - with 0.0, but should not be necessary because issues are only present when - r = 0, which will not occur with chemically meaningful inputs. - - Arguments: - num_basis_functions (int): - Number of radial basis functions. - x = infinity. - ini_alpha (float): - Initial value for scaling parameter alpha (Default value corresponds - to 0.5 1/Bohr converted to 1/Angstrom). - """ +class ExponentialBernsteinPolynomialsFactory: - def __init__( - self, - num_basis_functions: int, - ini_alpha: Quantity = 0.5 / unit.bohr, - dtype: Optional[torch.dtype] = None, - ) -> None: - """ Initializes the ExponentialBernsteinPolynomials class. """ - super(ExponentialBernsteinPolynomials, self).__init__() - self.ini_alpha = ini_alpha.to(1 / unit.nanometer).m - # compute values to initialize buffers - logfactorial = np.zeros(num_basis_functions) - for i in range(2, num_basis_functions): + @staticmethod + def make_radial_basis_function(number_of_radial_basis_functions: int, dtype: torch.dtype): + logfactorial = np.zeros(number_of_radial_basis_functions) + for i in range(2, number_of_radial_basis_functions): logfactorial[i] = logfactorial[i - 1] + np.log(i) - v = np.arange(0, num_basis_functions) - n = (num_basis_functions - 1) - v + v = np.arange(0, number_of_radial_basis_functions) + n = (number_of_radial_basis_functions - 1) - v logbinomial = logfactorial[-1] - logfactorial[v] - logfactorial[n] # register buffers and parameters - self.register_buffer("logc", torch.tensor(logbinomial, dtype=dtype)) - self.register_buffer("n", torch.tensor(n, dtype=dtype)) - self.register_buffer("v", torch.tensor(v, dtype=dtype)) - self.register_parameter( - "_alpha", nn.Parameter(torch.tensor(1.0, dtype=dtype)) - ) - self.reset_parameters() + radial_basis_function = ExponentialBernsteinPolynomialsCore + radial_basis_function.logc = torch.tensor(logbinomial, dtype=dtype) + radial_basis_function.n = torch.tensor(n, dtype=dtype) + radial_basis_function.v = torch.tensor(v, dtype=dtype) + return radial_basis_function - def reset_parameters(self) -> None: - """ Initialize exponential scaling parameter alpha. """ - nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha)) - - def forward(self, r: torch.Tensor) -> torch.Tensor: - """ - Evaluates radial basis functions given distances - N: Number of input values. - num_basis_functions: Number of radial basis functions. - Arguments: - r (FloatTensor [N]): - Input distances. +class ExponentialBernsteinPolynomials: + nn.init.constant_(self.alpha, alpha) + self.reset_parameters(ini_alpha.to(unit.nanometer).m) - Returns: - rbf (FloatTensor [N, num_basis_functions]): - Values of the radial basis functions for the distances r. 
- """ - alphar = -F.softplus(self._alpha) * r.view(-1, 1) - x = self.logc + self.n * alphar + self.v * torch.log(-torch.expm1(alphar)) - print(f"{self.logc.shape=}") - rbf = torch.exp(x) - return rbf * torch.exp(alphar) +def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: + return -(d_ij.view(-1, 1) / self.alpha) from typing import Dict @@ -413,6 +365,7 @@ def forward(self, x: torch.Tensor): return functional.softplus(x) - self.log_2 + def softplus_inverse(x): """ From SpookyNet: @@ -545,7 +498,9 @@ def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: class RadialBasisFunctionCore(ABC): - @staticmethod + + @abstractmethod + def __init__(self, ): @abstractmethod def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: """ @@ -563,16 +518,67 @@ def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): - @staticmethod - def compute(nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + def compute(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: return torch.exp(-(nondimensionalized_distances ** 2)) +class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore): + """ + Taken from SpookyNet. + Radial basis functions based on exponential Bernstein polynomials given by: + b_{v,n}(x) = (n over v) * exp(-alpha*x)**v * (1-exp(-alpha*x))**(n-v) + (see https://en.wikipedia.org/wiki/Bernstein_polynomial) + Here, n = num_basis_functions-1 and v takes values from 0 to n. This + implementation operates in log space to prevent multiplication of very large + (n over v) and very small numbers (exp(-alpha*x)**v and + (1-exp(-alpha*x))**(n-v)) for numerical stability. + NOTE: There is a problem for x = 0, as log(-expm1(0)) will be log(0) = -inf. + This itself is not an issue, but the buffer v contains an entry 0 and + 0*(-inf)=nan. The correct behaviour could be recovered by replacing the nan + with 0.0, but should not be necessary because issues are only present when + r = 0, which will not occur with chemically meaningful inputs. + + Arguments: + num_basis_functions (int): + Number of radial basis functions. + x = infinity. + ini_alpha (float): + Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original + default is 0.5/bohr, so we use 2 bohr). + """ + + def __init__(self, number_of_radial_basis_functions: int): + + + def reset_parameters(self, alpha) -> None: + """ Initialize exponential scaling parameter alpha. """ + + def compute(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + """ + Evaluates radial basis functions given distances + N: Number of input values. + num_basis_functions: Number of radial basis functions. + + Arguments: + nondimensionalized_distances (FloatTensor [N]): + Input distances. + + Returns: + rbf (FloatTensor [N, num_basis_functions]): + Values of the radial basis functions for the distances r. 
+ """ + x = (cls.logc + (cls.n + 1) * nondimensionalized_distances + + v * torch.log(-torch.expm1(nondimensionalized_distances))) + print(f"{cls.logc.shape=}") + + return torch.exp(x) + + class RadialBasisFunction(nn.Module, ABC): def __init__( self, - radial_basis_function: Type[RadialBasisFunctionCore], + radial_basis_function: Callable[[torch.Tensor], torch.Tensor], dtype, prefactor: float = 1.0, trainable_prefactor: bool = False, diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 69d93614..34ad2e3d 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -205,4 +205,6 @@ def test_spookynet_bernstein_polynomial_equivalence(): cutoff_values = torch.rand((N, 1)) ref_exp_bernstein_polynomial_result = ref_exp_bernstein_polynomials(r_angstrom, cutoff_values) mf_exp_bernstein_polynomial_result = mf_exp_bernstein_polynomials(r_nanometer) * cutoff_values - assert torch.equal(ref_exp_bernstein_polynomial_result, mf_exp_bernstein_polynomial_result) + print(f"{ref_exp_bernstein_polynomial_result=}") + print(f"{mf_exp_bernstein_polynomial_result=}") + assert torch.allclose(ref_exp_bernstein_polynomial_result, mf_exp_bernstein_polynomial_result) From 22acdc7cc428be635be42a82b193d59abe3298c2 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 5 Jul 2024 17:53:43 -0700 Subject: [PATCH 26/78] Add shape assertion in RadialBasisFunctionCore. Remove unnecessary unsqueeze in PhysNet radial basis function --- modelforge/potential/utils.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index c0daee52..f4eba8c3 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -426,7 +426,11 @@ def compute_angular_sub_aev(self, vectors12: torch.Tensor) -> torch.Tensor: class RadialBasisFunctionCore(nn.Module, ABC): - + + def __init__(self, number_of_radial_basis_functions): + super().__init__() + self.number_of_radial_basis_functions = number_of_radial_basis_functions + @abstractmethod def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: """ @@ -445,6 +449,12 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + assert nondimensionalized_distances.ndim == 2 + assert ( + nondimensionalized_distances.shape[1] + == self.number_of_radial_basis_functions + ) + return torch.exp(-(nondimensionalized_distances**2)) @@ -496,10 +506,9 @@ def forward(self, distances: torch.Tensor) -> torch.Tensor: torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] Output of radial basis functions. 
""" + print(f"{distances.shape=}") nondimensionalized_distances = self.nondimensionalize_distances(distances) - return self.prefactor * self.radial_basis_function( - nondimensionalized_distances - ) + return self.prefactor * self.radial_basis_function(nondimensionalized_distances) class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): @@ -537,7 +546,10 @@ def __init__( """ super().__init__( - GaussianRadialBasisFunctionCore(), dtype, prefactor, trainable_prefactor + GaussianRadialBasisFunctionCore(number_of_radial_basis_functions), + dtype, + prefactor, + trainable_prefactor, ) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.dtype = dtype @@ -768,7 +780,9 @@ def __init__( """ super().__init__( - GaussianRadialBasisFunctionCore(), trainable_prefactor=False, dtype=dtype + GaussianRadialBasisFunctionCore(number_of_radial_basis_functions), + trainable_prefactor=False, + dtype=dtype, ) self._max_distance_in_nanometer = max_distance.to(unit.nanometer).m self._min_distance_in_nanometer = min_distance.to(unit.nanometer).m @@ -850,7 +864,7 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: # nanometers, so we multiply by 10/nanometer return ( - torch.exp((-distances + self._min_distance_in_nanometer) * 10).unsqueeze(-1) + torch.exp((-distances + self._min_distance_in_nanometer) * 10) - self.radial_basis_centers ) / self.radial_scale_factor From cb28ef17ae99b454866d6a80c5cd21bad1c2f2c6 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 5 Jul 2024 17:54:16 -0700 Subject: [PATCH 27/78] Remove print statement --- modelforge/potential/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index f4eba8c3..0f373d45 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -506,7 +506,6 @@ def forward(self, distances: torch.Tensor) -> torch.Tensor: torch.Tensor, shape [number_of_pairs, number_of_radial_basis_functions] Output of radial basis functions. """ - print(f"{distances.shape=}") nondimensionalized_distances = self.nondimensionalize_distances(distances) return self.prefactor * self.radial_basis_function(nondimensionalized_distances) From 1e887835249ad20dbcf09788a15d742db3f8e6be Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 5 Jul 2024 18:14:28 -0700 Subject: [PATCH 28/78] Remove print statements in SAKE --- modelforge/potential/sake.py | 3 - modelforge/potential/utils.py | 218 +++++++++++++++++----------------- 2 files changed, 109 insertions(+), 112 deletions(-) diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index ffc2b95c..06fa4331 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -322,12 +322,9 @@ def update_edge(self, h_i_by_pair, h_j_by_pair, d_ij): Intermediate edge features. Shape [nr_pairs, nr_edge_basis]. 
""" h_ij_cat = torch.cat([h_i_by_pair, h_j_by_pair], dim=-1) - print(f"{self.radial_symmetry_function_module(d_ij.unsqueeze(-1)).shape=}") - print(f"{self.edge_mlp_in(h_ij_cat).shape=}") h_ij_filtered = self.radial_symmetry_function_module(d_ij.unsqueeze(-1)).squeeze(-2) * self.edge_mlp_in( h_ij_cat ) - print(f"{h_ij_filtered.shape=}, {h_ij_cat.shape=}, {d_ij.shape=}") return self.edge_mlp_out( torch.cat([h_ij_cat, h_ij_filtered, d_ij.unsqueeze(-1)], dim=-1) ) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 0f373d45..86bc5fa8 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -43,7 +43,7 @@ class Metadata: F: torch.Tensor = torch.tensor([], dtype=torch.float32) def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): """Move all tensors in this instance to the specified device.""" if device: @@ -65,9 +65,9 @@ class BatchData: metadata: Metadata def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): self.nnp_input = self.nnp_input.to(device=device, dtype=dtype) self.metadata = self.metadata.to(device=device, dtype=dtype) @@ -86,7 +86,7 @@ def shared_config_prior(): def triple_by_molecule( - atom_pairs: torch.Tensor, + atom_pairs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input: indices for pairs of atoms that are close to each other. each pair only appear once, i.e. only one of the pairs (1, 2) and @@ -122,8 +122,8 @@ def cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor: torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) ) mask = ( - torch.arange(intra_pair_indices.shape[2], device=ai1.device) - < pair_sizes.unsqueeze(1) + torch.arange(intra_pair_indices.shape[2], device=ai1.device) + < pair_sizes.unsqueeze(1) ).flatten() sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) @@ -200,13 +200,13 @@ class Dense(nn.Linear): """ def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - activation: Optional[nn.Module] = None, - weight_init: Callable = xavier_uniform_, - bias_init: Callable = zeros_, + self, + in_features: int, + out_features: int, + bias: bool = True, + activation: Optional[nn.Module] = None, + weight_init: Callable = xavier_uniform_, + bias_init: Callable = zeros_, ): """ Args: @@ -268,7 +268,7 @@ def forward(self, d_ij: torch.Tensor): """ # Compute values of cutoff function input_cut = 0.5 * ( - torch.cos(d_ij * np.pi / self.cutoff) + 1.0 + torch.cos(d_ij * np.pi / self.cutoff) + 1.0 ) # NOTE: ANI adds 0.5 instead of 1. 
# Remove contributions beyond the cutoff radius input_cut *= (d_ij < self.cutoff).float() @@ -312,13 +312,13 @@ class AngularSymmetryFunction(nn.Module): """ def __init__( - self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, - number_of_gaussians_for_asf: int = 8, - angle_sections: int = 4, - trainable: bool = False, - dtype: Optional[torch.dtype] = None, + self, + max_distance: unit.Quantity, + min_distance: unit.Quantity, + number_of_gaussians_for_asf: int = 8, + angle_sections: int = 4, + trainable: bool = False, + dtype: Optional[torch.dtype] = None, ) -> None: """ Parameters @@ -451,21 +451,21 @@ class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: assert nondimensionalized_distances.ndim == 2 assert ( - nondimensionalized_distances.shape[1] - == self.number_of_radial_basis_functions + nondimensionalized_distances.shape[1] + == self.number_of_radial_basis_functions ) - return torch.exp(-(nondimensionalized_distances**2)) + return torch.exp(-(nondimensionalized_distances ** 2)) class RadialBasisFunction(nn.Module, ABC): def __init__( - self, - radial_basis_function: RadialBasisFunctionCore, - dtype: torch.dtype, - prefactor: float = 1.0, - trainable_prefactor: bool = False, + self, + radial_basis_function: RadialBasisFunctionCore, + dtype: torch.dtype, + prefactor: float = 1.0, + trainable_prefactor: bool = False, ): super().__init__() if trainable_prefactor: @@ -516,14 +516,14 @@ class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): """ def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - prefactor: float = 1.0, - trainable_prefactor: bool = False, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + prefactor: float = 1.0, + trainable_prefactor: bool = False, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -583,10 +583,10 @@ def __init__( @staticmethod @abstractmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): """ NOTE: centers have units of nanometers @@ -596,10 +596,10 @@ def calculate_radial_basis_centers( @staticmethod @abstractmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): """ NOTE: radial scale factors have units of nanometers @@ -618,12 +618,12 @@ class SchnetRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): """ def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -650,10 +650,10 @@ def 
__init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): return torch.linspace( _min_distance_in_nanometer, @@ -664,10 +664,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): scale_factors = torch.linspace( _min_distance_in_nanometer, @@ -676,8 +676,8 @@ def calculate_radial_scale_factor( ) widths = ( - torch.abs(scale_factors[1] - scale_factors[0]) - * torch.ones_like(scale_factors) + torch.abs(scale_factors[1] - scale_factors[0]) + * torch.ones_like(scale_factors) ).to(dtype) scale_factors = math.sqrt(2) * widths @@ -690,12 +690,12 @@ class AniRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): """ def __init__( - self, - number_of_radial_basis_functions, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -723,10 +723,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): centers = torch.linspace( _min_distance_in_nanometer, @@ -738,10 +738,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # ANI uses a predefined scaling factor scale_factors = torch.full( @@ -756,12 +756,12 @@ class PhysNetRadialBasisFunction(RadialBasisFunction): """ def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -808,10 +808,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # initialize centers according to the default values in PhysNet # (see mu_k in Figure 2 caption of https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181) @@ -831,10 +831,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + 
number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # initialize according to the default values in PhysNet (see beta_k in Figure 2 caption) # NOTES: @@ -845,13 +845,13 @@ def calculate_radial_scale_factor( return torch.full( (number_of_radial_basis_functions,), ( - 2 - * ( - 1 - - math.exp( + 2 + * ( + 1 + - math.exp( 10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer) ) - ) + ) ) / number_of_radial_basis_functions, dtype=dtype, @@ -863,14 +863,14 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: # nanometers, so we multiply by 10/nanometer return ( - torch.exp((-distances + self._min_distance_in_nanometer) * 10) - - self.radial_basis_centers + torch.exp((-distances + self._min_distance_in_nanometer) * 10) + - self.radial_basis_centers ) / self.radial_scale_factor def pair_list( - atomic_subsystem_indices: torch.Tensor, - only_unique_pairs: bool = False, + atomic_subsystem_indices: torch.Tensor, + only_unique_pairs: bool = False, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -911,7 +911,7 @@ def pair_list( # filter pairs to only keep those belonging to the same molecule same_molecule_mask = ( - atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] + atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] ) # Apply mask to get final pair indices @@ -924,9 +924,9 @@ def pair_list( return pair_indices.to(device) def forward( - self, - coordinates: torch.Tensor, # in nanometer - atomic_subsystem_indices: torch.Tensor, + self, + coordinates: torch.Tensor, # in nanometer + atomic_subsystem_indices: torch.Tensor, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -958,11 +958,11 @@ def forward( def scatter_softmax( - src: torch.Tensor, - index: torch.Tensor, - dim: int, - dim_size: Optional[int] = None, - device: Optional[torch.device] = None, + src: torch.Tensor, + index: torch.Tensor, + dim: int, + dim_size: Optional[int] = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Softmax operation over all values in :attr:`src` tensor that share indices @@ -996,7 +996,7 @@ def scatter_softmax( assert dim >= 0, f"dim must be non-negative, got {dim}" assert ( - dim < src.dim() + dim < src.dim() ), f"dim must be less than the number of dimensions of src {src.dim()}, got {dim}" out_shape = [ From 5b0e9bc7a75aa209921b3b758ea6fca228f2d72b Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 5 Jul 2024 18:17:13 -0700 Subject: [PATCH 29/78] Fix SAKE RBF test --- modelforge/tests/test_sake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 0e822499..ae828b1d 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -242,7 +242,7 @@ def test_radial_symmetry_function_against_reference(): d_ij_jax = jax.random.uniform(key, (nr_atoms, nr_atoms, 1)) d_ij = torch.from_numpy( onp.array(d_ij_jax) - ).reshape(nr_atoms ** 2) + ).reshape((nr_atoms ** 2, 1)) mf_rbf = radial_symmetry_function_module(d_ij) variables = ref_radial_basis_module.init(key, d_ij_jax) From addc502ef755102be868ed57f45aada627fe136c Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 9 Jul 2024 09:08:37 -0700 Subject: [PATCH 30/78] Refactor exponential Bernstein polynomials --- modelforge/potential/utils.py | 70 ++++++++++++++++------------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git 
a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 6ac2e1d4..b98e4567 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -309,33 +309,6 @@ def forward(self, d_ij: torch.Tensor): ) -class ExponentialBernsteinPolynomialsFactory: - - @staticmethod - def make_radial_basis_function(number_of_radial_basis_functions: int, dtype: torch.dtype): - logfactorial = np.zeros(number_of_radial_basis_functions) - for i in range(2, number_of_radial_basis_functions): - logfactorial[i] = logfactorial[i - 1] + np.log(i) - v = np.arange(0, number_of_radial_basis_functions) - n = (number_of_radial_basis_functions - 1) - v - logbinomial = logfactorial[-1] - logfactorial[v] - logfactorial[n] - # register buffers and parameters - radial_basis_function = ExponentialBernsteinPolynomialsCore - radial_basis_function.logc = torch.tensor(logbinomial, dtype=dtype) - radial_basis_function.n = torch.tensor(n, dtype=dtype) - radial_basis_function.v = torch.tensor(v, dtype=dtype) - return radial_basis_function - - -class ExponentialBernsteinPolynomials: - nn.init.constant_(self.alpha, alpha) - self.reset_parameters(ini_alpha.to(unit.nanometer).m) - - -def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: - return -(d_ij.view(-1, 1) / self.alpha) - - from typing import Dict @@ -539,21 +512,27 @@ class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore): r = 0, which will not occur with chemically meaningful inputs. Arguments: - num_basis_functions (int): + number_of_radial_basis_functions (int): Number of radial basis functions. x = infinity. - ini_alpha (float): - Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original - default is 0.5/bohr, so we use 2 bohr). """ - def __init__(self, number_of_radial_basis_functions: int): + def __init__(self, number_of_radial_basis_functions: int): + logfactorial = np.zeros(number_of_radial_basis_functions) + for i in range(2, number_of_radial_basis_functions): + logfactorial[i] = logfactorial[i - 1] + np.log(i) + v = np.arange(0, number_of_radial_basis_functions) + n = (number_of_radial_basis_functions - 1) - v + logbinomial = logfactorial[-1] - logfactorial[v] - logfactorial[n] + # register buffers and parameters + dtype = torch.float64 # TODO: make this a parameter + self.logc = torch.tensor(logbinomial, dtype=dtype) + self.n = torch.tensor(n, dtype=dtype) + self.v = torch.tensor(v, dtype=dtype) - def reset_parameters(self, alpha) -> None: - """ Initialize exponential scaling parameter alpha. """ - def compute(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: """ Evaluates radial basis functions given distances N: Number of input values. @@ -567,9 +546,9 @@ def compute(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: rbf (FloatTensor [N, num_basis_functions]): Values of the radial basis functions for the distances r. """ - x = (cls.logc + (cls.n + 1) * nondimensionalized_distances - + v * torch.log(-torch.expm1(nondimensionalized_distances))) - print(f"{cls.logc.shape=}") + x = (self.logc + (self.n + 1) * nondimensionalized_distances + + self.v * torch.log(-torch.expm1(nondimensionalized_distances))) + print(f"{self.logc.shape=}") return torch.exp(x) @@ -661,7 +640,7 @@ def __init__( Whether centers and scale factors are trainable. 
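Notes
-----
Combined with the Gaussian core, each basis function effectively evaluates prefactor * exp(-((d_ij - center_k) / scale_k) ** 2): forward() first nondimensionalizes the distances with the per-basis centers and scale factors computed by the two static methods below, then applies GaussianRadialBasisFunctionCore.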
""" - super().__init__(GaussianRadialBasisFunctionCore, dtype, prefactor, trainable_prefactor) + super().__init__(GaussianRadialBasisFunctionCore(), dtype, prefactor, trainable_prefactor) self.number_of_radial_basis_functions = number_of_radial_basis_functions self.dtype = dtype self.trainable_centers_and_scale_factors = trainable_centers_and_scale_factors @@ -927,6 +906,19 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: / self.radial_scale_factor) +class ExponentialBernsteinRadialBasisFunction(RadialBasisFunction): + + def __init__(self, ini_alpha): + """ + ini_alpha (float): + Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original + default is 0.5/bohr, so we use 2 bohr). + """ + self.ini_alpha = ini_alpha + def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: + return -(d_ij.view(-1, 1) / self.alpha) + + def pair_list( atomic_subsystem_indices: torch.Tensor, only_unique_pairs: bool = False, From 1f11f3a525ca797dc1a9adf6fe12935a02dd56d2 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 9 Jul 2024 09:30:01 -0700 Subject: [PATCH 31/78] Fix SchNet tests --- modelforge/tests/test_schnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index 5ede7fb0..0ba25a2f 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -55,7 +55,7 @@ def test_Schnet_init(): def test_compare_radial_symmetry_features(): # compare schnetpack RadialSymmetryFunction with modelforge RadialSymmetryFunction - from modelforge.potential.utils import SchnetRadialSymmetryFunction + from modelforge.potential.utils import SchnetRadialBasisFunction from openff.units import unit # Initialize the RBFs @@ -63,7 +63,7 @@ def test_compare_radial_symmetry_features(): cutoff = unit.Quantity(5.2, unit.angstrom) start = unit.Quantity(0.8, unit.angstrom) - radial_symmetry_function_module = SchnetRadialSymmetryFunction( + radial_symmetry_function_module = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_gaussians, max_distance=cutoff, min_distance=start, @@ -228,10 +228,10 @@ def test_schnet_forward_pass(): dtype=torch.float64, ) modelforge_phi_ij = modelforge_schnet.core_module.schnet_representation_module.radial_symmetry_function_module( - d_ij.unsqueeze(1) / 10 + d_ij / 10 ) # NOTE: converting to nm - assert torch.allclose(schnetpack_phi_ij, modelforge_phi_ij, atol=1e-3) + assert torch.allclose(schnetpack_phi_ij, modelforge_phi_ij.unsqueeze(1), atol=1e-3) # ---------------------------------------- # # test cutoff # ---------------------------------------- # From 2109f3c674c108bf3d92bdef5ce011d692f69f6f Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 9 Jul 2024 09:41:17 -0700 Subject: [PATCH 32/78] Fix spk tests --- modelforge/tests/test_spk.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/modelforge/tests/test_spk.py b/modelforge/tests/test_spk.py index 6bf5ffa5..e50bcafd 100644 --- a/modelforge/tests/test_spk.py +++ b/modelforge/tests/test_spk.py @@ -209,8 +209,6 @@ def test_painn_representation_implementation(): for i in range(nr_of_interactions) for dense in schnetpack_painn.interactions[i].interatomic_context_net ] - print(modelforge_painn.core_module.interaction_modules[0].interatomic_net[0].weight) - print(schnetpack_painn.interactions[0].interatomic_context_net[0].weight) assert torch.allclose( 
modelforge_painn.core_module.interaction_modules[0].interatomic_net[0].weight, @@ -522,10 +520,10 @@ def test_schnet_representation_implementation(): d_ij = torch.norm(r_ij, dim=1, keepdim=True) schnetpack_phi_ij = schnetpack_schnet.radial_basis(d_ij) modelforge_phi_ij = modelforge_schnet.core_module.schnet_representation_module.radial_symmetry_function_module( - d_ij.unsqueeze(1) / 10 + d_ij / 10 ) # NOTE: converting to nm - assert torch.allclose(schnetpack_phi_ij, modelforge_phi_ij) + assert torch.allclose(schnetpack_phi_ij, modelforge_phi_ij.unsqueeze(1)) phi_ij = schnetpack_phi_ij # ---------------------------------------- # # test cutoff @@ -556,7 +554,7 @@ def test_schnet_representation_implementation(): # test representation # --------------------------------------- # f_ij_mf = modelforge_schnet.core_module.schnet_representation_module.radial_symmetry_function_module( - d_ij.unsqueeze(1) / 10 + d_ij / 10 ) r_cut_ij_mf = ( modelforge_schnet.core_module.schnet_representation_module.cutoff_module( @@ -569,7 +567,7 @@ def test_schnet_representation_implementation(): f_ij_spk = schnetpack_schnet.radial_basis(d_ij) rcut_ij_spk = schnetpack_schnet.cutoff_fn(d_ij) - f_ij_mf_ = f_ij_mf.squeeze(1) + f_ij_mf_ = f_ij_mf r_cut_ij_mf_ = r_cut_ij_mf.squeeze(1) assert torch.allclose(f_ij_mf_, f_ij_spk) assert torch.allclose(r_cut_ij_mf_, rcut_ij_spk) @@ -637,14 +635,14 @@ def test_schnet_representation_implementation(): assert torch.allclose(v_spk, v_mf) # Check full pass - modelforge_results = modelforge_schnet.core_module._forward(schnet_nn_input_mf) + modelforge_results = modelforge_schnet.core_module.compute_properties(schnet_nn_input_mf) schnetpack_results = schnetpack_schnet(spk_input) assert ( schnetpack_results["scalar_representation"].shape - == modelforge_results["q"].shape + == modelforge_results["scalar_representation"].shape ) scalar_spk = schnetpack_results["scalar_representation"] - scalar_mf = modelforge_results["q"] + scalar_mf = modelforge_results["scalar_representation"] assert torch.allclose(scalar_spk, scalar_mf, atol=1e-4) From d08c9c64f816f7ce4ff2c0ded101d44097271446 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 9 Jul 2024 09:47:16 -0700 Subject: [PATCH 33/78] Clean spk test --- modelforge/tests/test_spk.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modelforge/tests/test_spk.py b/modelforge/tests/test_spk.py index e50bcafd..8c58540a 100644 --- a/modelforge/tests/test_spk.py +++ b/modelforge/tests/test_spk.py @@ -494,7 +494,6 @@ def test_schnet_representation_implementation(): spk_input = input["spk_methane_input"] mf_nnp_input = input["modelforge_methane_input"] - schnetpack_results = schnetpack_schnet(spk_input) modelforge_schnet.input_preparation._input_checks(mf_nnp_input) pairlist_output = modelforge_schnet.input_preparation.prepare_inputs(mf_nnp_input) @@ -510,8 +509,6 @@ def test_schnet_representation_implementation(): assert torch.allclose(spk_input["_Rij"] / 10, schnet_nn_input_mf.r_ij, atol=1e-4) assert torch.allclose(spk_input["_idx_i"], schnet_nn_input_mf.pair_indices[0]) assert torch.allclose(spk_input["_idx_j"], schnet_nn_input_mf.pair_indices[1]) - idx_i = spk_input["_idx_i"] - idx_j = spk_input["_idx_j"] # ---------------------------------------- # # test radial symmetry function @@ -567,9 +564,8 @@ def test_schnet_representation_implementation(): f_ij_spk = schnetpack_schnet.radial_basis(d_ij) rcut_ij_spk = schnetpack_schnet.cutoff_fn(d_ij) - f_ij_mf_ = f_ij_mf r_cut_ij_mf_ = r_cut_ij_mf.squeeze(1) - assert 
torch.allclose(f_ij_mf_, f_ij_spk) + assert torch.allclose(f_ij_mf, f_ij_spk) assert torch.allclose(r_cut_ij_mf_, rcut_ij_spk) # ---------------------------------------- # From 0ff04f7598c7a5b27ab09568bdae5ee779d4e46c Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 9 Jul 2024 09:59:08 -0700 Subject: [PATCH 34/78] Working on spookynet Bernstein polynomials. Weird shape assertion fail. --- modelforge/potential/utils.py | 14 ++++++++++++-- modelforge/tests/test_spookynet.py | 4 ++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 043e6946..a2233a4b 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -516,6 +516,7 @@ class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore): def __init__(self, number_of_radial_basis_functions: int): + super().__init__(number_of_radial_basis_functions) logfactorial = np.zeros(number_of_radial_basis_functions) for i in range(2, number_of_radial_basis_functions): logfactorial[i] = logfactorial[i - 1] + np.log(i) @@ -543,6 +544,10 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: rbf (FloatTensor [N, num_basis_functions]): Values of the radial basis functions for the distances r. """ + print(f"{nondimensionalized_distances.shape=}") + print(f"{self.number_of_radial_basis_functions=}") + assert nondimensionalized_distances.ndim == 2 + assert nondimensionalized_distances.shape[1] == self.number_of_radial_basis_functions x = (self.logc + (self.n + 1) * nondimensionalized_distances + self.v * torch.log(-torch.expm1(nondimensionalized_distances))) print(f"{self.logc.shape=}") @@ -962,13 +967,18 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: class ExponentialBernsteinRadialBasisFunction(RadialBasisFunction): - def __init__(self, ini_alpha): + def __init__(self, number_of_radial_basis_functions, ini_alpha, dtype=torch.int64): """ ini_alpha (float): Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original default is 0.5/bohr, so we use 2 bohr). 
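Note that this version stores ini_alpha directly as self.alpha and divides the raw d_ij by it in nondimensionalize_distances, so alpha is assumed to carry the same length unit as the incoming distances (nanometers elsewhere in modelforge); no unit conversion is performed here.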
""" - self.ini_alpha = ini_alpha + super().__init__( + ExponentialBernsteinPolynomialsCore(number_of_radial_basis_functions), + trainable_prefactor=False, + dtype=dtype, + ) + self.alpha = ini_alpha def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: return -(d_ij.view(-1, 1) / self.alpha) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 34ad2e3d..8dc2180e 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -193,11 +193,11 @@ def test_spookynet_interaction_module_against_reference(): def test_spookynet_bernstein_polynomial_equivalence(): from spookynet.modules.exponential_bernstein_polynomials import ExponentialBernsteinPolynomials as RefExponentialBernsteinPolynomials - from modelforge.potential.utils import ExponentialBernsteinPolynomials as MfExponentialBernSteinPolynomials + from modelforge.potential.utils import ExponentialBernsteinRadialBasisFunction as MfExponentialBernSteinPolynomials num_basis_functions = 3 ref_exp_bernstein_polynomials = RefExponentialBernsteinPolynomials(num_basis_functions, exp_weighting=True) - mf_exp_bernstein_polynomials = MfExponentialBernSteinPolynomials(num_basis_functions) + mf_exp_bernstein_polynomials = MfExponentialBernSteinPolynomials(num_basis_functions, ini_alpha=1.0) N = 5 r_angstrom = torch.rand((N, 1)) From 3da23301ef74eaf4fad301069613373dc6b925c7 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 9 Jul 2024 16:47:09 -0700 Subject: [PATCH 35/78] Broadcast to number of radial basis functions in nondimensionalization of exponential Bernstein polynomials --- modelforge/potential/utils.py | 252 ++++++++++++++++++---------------- 1 file changed, 132 insertions(+), 120 deletions(-) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index a2233a4b..ff784d73 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -43,7 +43,7 @@ class Metadata: F: torch.Tensor = torch.tensor([], dtype=torch.float32) def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None ): """Move all tensors in this instance to the specified device.""" if device: @@ -65,9 +65,9 @@ class BatchData: metadata: Metadata def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): self.nnp_input = self.nnp_input.to(device=device, dtype=dtype) self.metadata = self.metadata.to(device=device, dtype=dtype) @@ -86,7 +86,7 @@ def shared_config_prior(): def triple_by_molecule( - atom_pairs: torch.Tensor, + atom_pairs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Input: indices for pairs of atoms that are close to each other. each pair only appear once, i.e. 
only one of the pairs (1, 2) and @@ -122,8 +122,8 @@ def cumsum_from_zero(input_: torch.Tensor) -> torch.Tensor: torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1) ) mask = ( - torch.arange(intra_pair_indices.shape[2], device=ai1.device) - < pair_sizes.unsqueeze(1) + torch.arange(intra_pair_indices.shape[2], device=ai1.device) + < pair_sizes.unsqueeze(1) ).flatten() sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask] sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices) @@ -200,13 +200,13 @@ class Dense(nn.Linear): """ def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - activation: Optional[nn.Module] = None, - weight_init: Callable = xavier_uniform_, - bias_init: Callable = zeros_, + self, + in_features: int, + out_features: int, + bias: bool = True, + activation: Optional[nn.Module] = None, + weight_init: Callable = xavier_uniform_, + bias_init: Callable = zeros_, ): """ Args: @@ -268,7 +268,7 @@ def forward(self, d_ij: torch.Tensor): """ # Compute values of cutoff function input_cut = 0.5 * ( - torch.cos(d_ij * np.pi / self.cutoff) + 1.0 + torch.cos(d_ij * np.pi / self.cutoff) + 1.0 ) # NOTE: ANI adds 0.5 instead of 1. # Remove contributions beyond the cutoff radius input_cut *= (d_ij < self.cutoff).float() @@ -305,7 +305,9 @@ def forward(self, d_ij: torch.Tensor): zeros = torch.zeros_like(d_ij) r_ = torch.where(d_ij < self.cutoff, d_ij, zeros) # prevent nan in backprop return torch.where( - d_ij < self.cutoff, torch.exp(-(r_ ** 2) / ((self.cutoff - r_) * (self.cutoff + r_))), zeros + d_ij < self.cutoff, + torch.exp(-(r_**2) / ((self.cutoff - r_) * (self.cutoff + r_))), + zeros, ) @@ -346,13 +348,13 @@ class AngularSymmetryFunction(nn.Module): """ def __init__( - self, - max_distance: unit.Quantity, - min_distance: unit.Quantity, - number_of_gaussians_for_asf: int = 8, - angle_sections: int = 4, - trainable: bool = False, - dtype: Optional[torch.dtype] = None, + self, + max_distance: unit.Quantity, + min_distance: unit.Quantity, + number_of_gaussians_for_asf: int = 8, + angle_sections: int = 4, + trainable: bool = False, + dtype: Optional[torch.dtype] = None, ) -> None: """ Parameters @@ -485,11 +487,11 @@ class GaussianRadialBasisFunctionCore(RadialBasisFunctionCore): def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: assert nondimensionalized_distances.ndim == 2 assert ( - nondimensionalized_distances.shape[1] - == self.number_of_radial_basis_functions + nondimensionalized_distances.shape[1] + == self.number_of_radial_basis_functions ) - return torch.exp(-(nondimensionalized_distances ** 2)) + return torch.exp(-(nondimensionalized_distances**2)) class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore): @@ -514,7 +516,6 @@ class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore): x = infinity. 
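Concretely, for K = number_of_radial_basis_functions and t = exp(x), where x <= 0 is the nondimensionalized distance, the k-th basis function computed in forward() is binom(K - 1, k) * t ** (K - k) * (1 - t) ** k; the quantity logc + (n + 1) * x + v * log(-expm1(x)) evaluated there is exactly the logarithm of this expression, which is why the computation can stay in log space until the final exp.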
""" - def __init__(self, number_of_radial_basis_functions: int): super().__init__(number_of_radial_basis_functions) logfactorial = np.zeros(number_of_radial_basis_functions) @@ -524,12 +525,11 @@ def __init__(self, number_of_radial_basis_functions: int): n = (number_of_radial_basis_functions - 1) - v logbinomial = logfactorial[-1] - logfactorial[v] - logfactorial[n] # register buffers and parameters - dtype = torch.float64 # TODO: make this a parameter + dtype = torch.float64 # TODO: make this a parameter self.logc = torch.tensor(logbinomial, dtype=dtype) self.n = torch.tensor(n, dtype=dtype) self.v = torch.tensor(v, dtype=dtype) - def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: """ Evaluates radial basis functions given distances @@ -547,9 +547,15 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: print(f"{nondimensionalized_distances.shape=}") print(f"{self.number_of_radial_basis_functions=}") assert nondimensionalized_distances.ndim == 2 - assert nondimensionalized_distances.shape[1] == self.number_of_radial_basis_functions - x = (self.logc + (self.n + 1) * nondimensionalized_distances - + self.v * torch.log(-torch.expm1(nondimensionalized_distances))) + assert ( + nondimensionalized_distances.shape[1] + == self.number_of_radial_basis_functions + ) + x = ( + self.logc + + (self.n + 1) * nondimensionalized_distances + + self.v * torch.log(-torch.expm1(nondimensionalized_distances)) + ) print(f"{self.logc.shape=}") return torch.exp(x) @@ -558,11 +564,11 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: class RadialBasisFunction(nn.Module, ABC): def __init__( - self, - radial_basis_function: RadialBasisFunctionCore, - dtype: torch.dtype, - prefactor: float = 1.0, - trainable_prefactor: bool = False, + self, + radial_basis_function: RadialBasisFunctionCore, + dtype: torch.dtype, + prefactor: float = 1.0, + trainable_prefactor: bool = False, ): super().__init__() if trainable_prefactor: @@ -613,14 +619,14 @@ class GaussianRadialBasisFunctionWithScaling(RadialBasisFunction): """ def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - prefactor: float = 1.0, - trainable_prefactor: bool = False, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + prefactor: float = 1.0, + trainable_prefactor: bool = False, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -680,10 +686,10 @@ def __init__( @staticmethod @abstractmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): """ NOTE: centers have units of nanometers @@ -693,10 +699,10 @@ def calculate_radial_basis_centers( @staticmethod @abstractmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): """ NOTE: radial scale factors have units of nanometers @@ -715,12 +721,12 @@ class 
SchnetRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): """ def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: Optional[torch.dtype] = None, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: Optional[torch.dtype] = None, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -747,10 +753,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): return torch.linspace( _min_distance_in_nanometer, @@ -761,10 +767,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): scale_factors = torch.linspace( _min_distance_in_nanometer, @@ -773,8 +779,8 @@ def calculate_radial_scale_factor( ) widths = ( - torch.abs(scale_factors[1] - scale_factors[0]) - * torch.ones_like(scale_factors) + torch.abs(scale_factors[1] - scale_factors[0]) + * torch.ones_like(scale_factors) ).to(dtype) scale_factors = math.sqrt(2) * widths @@ -787,12 +793,12 @@ class AniRadialBasisFunction(GaussianRadialBasisFunctionWithScaling): """ def __init__( - self, - number_of_radial_basis_functions, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -820,10 +826,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): centers = torch.linspace( _min_distance_in_nanometer, @@ -835,10 +841,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # ANI uses a predefined scaling factor scale_factors = torch.full( @@ -853,12 +859,12 @@ class PhysNetRadialBasisFunction(RadialBasisFunction): """ def __init__( - self, - number_of_radial_basis_functions: int, - max_distance: unit.Quantity, - min_distance: unit.Quantity = 0.0 * unit.nanometer, - dtype: torch.dtype = torch.float32, - trainable_centers_and_scale_factors: bool = False, + self, + number_of_radial_basis_functions: int, + max_distance: unit.Quantity, + min_distance: unit.Quantity = 0.0 * unit.nanometer, + dtype: torch.dtype = torch.float32, + trainable_centers_and_scale_factors: bool = False, ): """ Parameters @@ -905,10 +911,10 @@ def __init__( @staticmethod def calculate_radial_basis_centers( - 
number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # initialize centers according to the default values in PhysNet # (see mu_k in Figure 2 caption of https://pubs.acs.org/doi/10.1021/acs.jctc.9b00181) @@ -928,10 +934,10 @@ def calculate_radial_basis_centers( @staticmethod def calculate_radial_scale_factor( - number_of_radial_basis_functions, - _max_distance_in_nanometer, - _min_distance_in_nanometer, - dtype, + number_of_radial_basis_functions, + _max_distance_in_nanometer, + _min_distance_in_nanometer, + dtype, ): # initialize according to the default values in PhysNet (see beta_k in Figure 2 caption) # NOTES: @@ -942,13 +948,13 @@ def calculate_radial_scale_factor( return torch.full( (number_of_radial_basis_functions,), ( - 2 - * ( - 1 - - math.exp( + 2 + * ( + 1 + - math.exp( 10 * (-_max_distance_in_nanometer + _min_distance_in_nanometer) ) - ) + ) ) / number_of_radial_basis_functions, dtype=dtype, @@ -960,8 +966,8 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: # nanometers, so we multiply by 10/nanometer return ( - torch.exp((-distances + self._min_distance_in_nanometer) * 10) - - self.radial_basis_centers + torch.exp((-distances + self._min_distance_in_nanometer) * 10) + - self.radial_basis_centers ) / self.radial_scale_factor @@ -969,9 +975,9 @@ class ExponentialBernsteinRadialBasisFunction(RadialBasisFunction): def __init__(self, number_of_radial_basis_functions, ini_alpha, dtype=torch.int64): """ - ini_alpha (float): - Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original - default is 0.5/bohr, so we use 2 bohr). + ini_alpha (float): + Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original + default is 0.5/bohr, so we use 2 bohr). """ super().__init__( ExponentialBernsteinPolynomialsCore(number_of_radial_basis_functions), @@ -979,13 +985,19 @@ def __init__(self, number_of_radial_basis_functions, ini_alpha, dtype=torch.int6 dtype=dtype, ) self.alpha = ini_alpha + def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: - return -(d_ij.view(-1, 1) / self.alpha) + return -( + d_ij.broadcast_to( + (len(d_ij), self.radial_basis_function.number_of_radial_basis_functions) + ) + / self.alpha + ) def pair_list( - atomic_subsystem_indices: torch.Tensor, - only_unique_pairs: bool = False, + atomic_subsystem_indices: torch.Tensor, + only_unique_pairs: bool = False, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. @@ -1026,7 +1038,7 @@ def pair_list( # filter pairs to only keep those belonging to the same molecule same_molecule_mask = ( - atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] + atomic_subsystem_indices[i_indices] == atomic_subsystem_indices[j_indices] ) # Apply mask to get final pair indices @@ -1039,9 +1051,9 @@ def pair_list( return pair_indices.to(device) def forward( - self, - coordinates: torch.Tensor, # in nanometer - atomic_subsystem_indices: torch.Tensor, + self, + coordinates: torch.Tensor, # in nanometer + atomic_subsystem_indices: torch.Tensor, ) -> torch.Tensor: """Compute all pairs of atoms and their distances. 
@@ -1073,11 +1085,11 @@ def forward( def scatter_softmax( - src: torch.Tensor, - index: torch.Tensor, - dim: int, - dim_size: Optional[int] = None, - device: Optional[torch.device] = None, + src: torch.Tensor, + index: torch.Tensor, + dim: int, + dim_size: Optional[int] = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Softmax operation over all values in :attr:`src` tensor that share indices @@ -1111,7 +1123,7 @@ def scatter_softmax( assert dim >= 0, f"dim must be non-negative, got {dim}" assert ( - dim < src.dim() + dim < src.dim() ), f"dim must be less than the number of dimensions of src {src.dim()}, got {dim}" out_shape = [ From e24b6fab078c2f6b09547df10b1ddfca809b4b36 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 10 Jul 2024 10:25:21 +0200 Subject: [PATCH 36/78] update toml --- .../potential_defaults/ani2x_without_ase.toml | 21 --------------- .../potential_defaults/painn_without_ase.toml | 22 --------------- .../physnet_without_ase.toml | 21 --------------- .../potential_defaults/sake_without_ase.toml | 21 --------------- .../tests/data/potential_defaults/schnet.toml | 18 +++---------- .../schnet_without_ase.toml | 22 --------------- modelforge/tests/data/training/default.toml | 27 ------------------- .../tests/data/training_defaults/default.toml | 8 +++--- 8 files changed, 8 insertions(+), 152 deletions(-) delete mode 100644 modelforge/tests/data/potential_defaults/ani2x_without_ase.toml delete mode 100644 modelforge/tests/data/potential_defaults/painn_without_ase.toml delete mode 100644 modelforge/tests/data/potential_defaults/physnet_without_ase.toml delete mode 100644 modelforge/tests/data/potential_defaults/sake_without_ase.toml delete mode 100644 modelforge/tests/data/potential_defaults/schnet_without_ase.toml delete mode 100644 modelforge/tests/data/training/default.toml diff --git a/modelforge/tests/data/potential_defaults/ani2x_without_ase.toml b/modelforge/tests/data/potential_defaults/ani2x_without_ase.toml deleted file mode 100644 index 0c8475e3..00000000 --- a/modelforge/tests/data/potential_defaults/ani2x_without_ase.toml +++ /dev/null @@ -1,21 +0,0 @@ -[potential] -model_name = "ANI2x" - -[potential.potential_parameter] -angle_sections = 4 -radial_max_distance = "5.1 angstrom" -radial_min_distance = "0.8 angstrom" -number_of_radial_basis_functions = 16 -angular_max_distance = "3.5 angstrom" -angular_min_distance = "0.8 angstrom" -angular_dist_divisions = 8 - -processing_operation = [ - { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, -] diff --git a/modelforge/tests/data/potential_defaults/painn_without_ase.toml b/modelforge/tests/data/potential_defaults/painn_without_ase.toml deleted file mode 100644 index 0c7c3302..00000000 --- a/modelforge/tests/data/potential_defaults/painn_without_ase.toml +++ /dev/null @@ -1,22 +0,0 @@ -[potential] -model_name = "PaiNN" - -[potential.potential_parameter] - -max_Z = 101 -number_of_atom_features = 32 -number_of_radial_basis_functions = 20 -cutoff = "5.0 angstrom" -number_of_interaction_modules = 3 -shared_interactions = false -shared_filters = false - -processing_operation = [ - { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 
'E' }, -] diff --git a/modelforge/tests/data/potential_defaults/physnet_without_ase.toml b/modelforge/tests/data/potential_defaults/physnet_without_ase.toml deleted file mode 100644 index eb0df83e..00000000 --- a/modelforge/tests/data/potential_defaults/physnet_without_ase.toml +++ /dev/null @@ -1,21 +0,0 @@ -[potential] -model_name = "PhysNet" - -[potential.potential_parameter] - -max_Z = 101 -number_of_atom_features = 64 -number_of_radial_basis_functions = 16 -cutoff = "5.0 angstrom" -number_of_interaction_residual = 3 -number_of_modules = 5 - -processing_operation = [ - { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, -] diff --git a/modelforge/tests/data/potential_defaults/sake_without_ase.toml b/modelforge/tests/data/potential_defaults/sake_without_ase.toml deleted file mode 100644 index 97428c74..00000000 --- a/modelforge/tests/data/potential_defaults/sake_without_ase.toml +++ /dev/null @@ -1,21 +0,0 @@ -[potential] -model_name = "SAKE" - -[potential.potential_parameter] - -max_Z = 101 -number_of_atom_features = 64 -number_of_radial_basis_functions = 50 -cutoff = "5.0 angstrom" -number_of_interaction_modules = 6 -number_of_spatial_attention_heads = 4 - -processing_operation = [ - { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, -] diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 42cf72e3..7289477f 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -11,18 +11,8 @@ number_of_interaction_modules = 3 number_of_filters = 32 shared_interactions = false -processing_operation = [ - { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - { in = [ - "atomic_numbers", - 'atomic_subsystem_indices', - ], out = "ase", step = "calculate_ase" }, +[potential.postprocessing] +[potential.postprocessing.energy] +normalize = true +from_atom_to_molecule_reduction = true -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, - { step = "from_atom_to_molecule", mode = 'sum', in = 'ase', index_key = 'atomic_subsystem_indices', out = 'mse' }, - -] diff --git a/modelforge/tests/data/potential_defaults/schnet_without_ase.toml b/modelforge/tests/data/potential_defaults/schnet_without_ase.toml deleted file mode 100644 index 06e81fef..00000000 --- a/modelforge/tests/data/potential_defaults/schnet_without_ase.toml +++ /dev/null @@ -1,22 +0,0 @@ -[potential] -model_name = "SchNet" - -[potential.potential_parameter] - -max_Z = 101 -number_of_atom_features = 32 -number_of_radial_basis_functions = 20 -cutoff = "5.0 angstrom" -number_of_interaction_modules = 3 -number_of_filters = 32 -shared_interactions = false - -processing_operation = [ - { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, -] diff --git a/modelforge/tests/data/training/default.toml 
b/modelforge/tests/data/training/default.toml deleted file mode 100644 index efa2d479..00000000 --- a/modelforge/tests/data/training/default.toml +++ /dev/null @@ -1,27 +0,0 @@ -[training] - -training_parameter.lr = 1e-3 -training_parameter.lr_scheduler_config.frequency = 1 -training_parameter.lr_scheduler_config.mode = "min" -training_parameter.lr_scheduler_config.factor = 0.1 -training_parameter.lr_scheduler_config.patience = 10 -training_parameter.lr_scheduler_config.cooldown = 5 -training_parameter.lr_scheduler_config.min_lr = 1e-8 -training_parameter.lr_scheduler_config.threshold = 0.1 -training_parameter.lr_scheduler_config.threshold_mode = "abs" -training_parameter.lr_scheduler_config.monitor = "val/energy/rmse" -training_parameter.lr_scheduler_config.interval = "epoch" - -loss_parameter.loss_type = "EnergyAndForceLoss" -loss_parameter.include_force = true -loss_parameter.force_weight = 1.0 -loss_parameter.energy_weight = 1.0 - - -early_stopping.monitor = "val/energy/rmse" -early_stopping.min_delta = 0.01 -early_stopping.patience = 50 -early_stopping.verbose = true - -stochastic_weight_averaging_config.swa_epoch_start = 50 -stochastic_weight_averaging_config.swa_lrs = 1e-2 diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml index 299ec550..d6211ac8 100644 --- a/modelforge/tests/data/training_defaults/default.toml +++ b/modelforge/tests/data/training_defaults/default.toml @@ -24,10 +24,10 @@ monitor = "val/energy/rmse" interval = "epoch" [training.training_parameter.loss_parameter] -loss_type = "EnergyAndForceLoss" -include_force = true -force_weight = 0.99 -energy_weight = 0.01 +loss_property = ['energy', 'force'] +[loss_parameter.weight] +energy = 0.999 +force = 0.001 [training.early_stopping] verbose = true From 23ea05ffbad0ecd9be8c3afc04efb91fa407f91c Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 10 Jul 2024 12:46:03 +0200 Subject: [PATCH 37/78] transition to new version of postprocessing --- modelforge/potential/models.py | 164 +++++++++++++++-------------- modelforge/potential/processing.py | 9 +- modelforge/potential/schnet.py | 6 +- modelforge/tests/test_models.py | 9 +- modelforge/train/training.py | 2 - 5 files changed, 94 insertions(+), 96 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index ad1cab97..133a3501 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -525,8 +525,7 @@ class NeuralNetworkPotentialFactory: def generate_model( *, use: Literal["training", "inference"], - model_type: Literal["ANI2x", "SchNet", "PaiNN", "SAKE", "PhysNet"], - model_parameter: Dict[str, Union[int, float, str, List[str]]], + model_parameter: Dict[str, Union[str, Any]], simulation_environment: Literal["PyTorch", "JAX"] = "PyTorch", training_parameter: Optional[Dict[str, Any]] = None, dataset_statistic: Optional[Dict[str, float]] = None, @@ -538,8 +537,6 @@ def generate_model( ---------- use : str The use case for the NNP instance. - model_name : str - The type of NNP to instantiate. simulation_environment : str The ML framework to use, either 'PyTorch' or 'JAX'. 
nnp_parameters : dict, optional @@ -564,12 +561,13 @@ def generate_model( from modelforge.train.training import TrainingAdapter log.debug(f"{training_parameter=}") + log.debug(f"{model_parameter=}") # get model + model_type = model_parameter["model_name"] nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_type) - if nnp_class is None: - raise NotImplementedError(f"NNP type {model_type} is not implemented.") # add modifications to NNP if requested + if use == "training": if simulation_environment == "JAX": log.warning( @@ -587,7 +585,11 @@ def generate_model( # for training the `model_name` might have been set if "model_name" in model_parameter: del model_parameter["model_name"] - model = nnp_class(**model_parameter, dataset_statistic=dataset_statistic) + model = nnp_class( + **model_parameter["core_parameter"], + postprocessing_parameter=model_parameter["postprocessing"], + dataset_statistic=dataset_statistic, + ) if simulation_environment == "JAX": return PyTorch2JAXConverter().convert_to_jax_model(model) else: @@ -687,37 +689,62 @@ def _input_checks(self, data: Union[NNPInput, NamedTuple]): assert data.positions.shape == torch.Size([nr_of_atoms, 3]) -from torch.nn import ModuleList +from torch.nn import ModuleDict class PostProcessing(torch.nn.Module): + + _SUPPORTED_PROPERTIES = ["energy"] + _SUPPORTED_OPERATIONS = ["normalize", "from_atom_to_molecule_reduction"] + def __init__( self, - processing_operation: List[Dict[str, str]], - readout_operation: List[Dict[str, str]], - dataset_statistic, + postprocessing_parameter: Dict[str, Dict[str, bool]], + dataset_statistic: Dict[str, Dict[str, float]], ): """ Parameters ---------- - model_parameter : Dict[str, Any] - The model parameters. + postprocessing_parameter: Dict[str, Dict[str, bool]] # TODO: update dataset_statistic : Dict[str, float] The dataset statistics. """ super().__init__() + self._registered_properties: List[str] = [] + + # operations that use nn.Sequence to pass the output of the model to the next + self.registered_chained_operations = ModuleDict() + # operations that don't requre any nn.Sequence + self.registered_independent_operations = ModuleDict() + + self.dataset_statistic = dataset_statistic + self._initialize_postprocessing( - processing_operation, - readout_operation, - dataset_statistic, + postprocessing_parameter, ) + def _get_mean_and_stddev_of_dataset(self) -> Tuple[float, float]: + + if self.dataset_statistic is None: + mean = 0.0 + stddev = 1.0 + log.warning( + f"No mean and stddev provided for dataset. Setting to default value {mean=} and {stddev=}!" + ) + else: + atomic_energies_stats = self.dataset_statistic["atomic_energies_stats"] + mean = unit.Quantity(atomic_energies_stats["mean"]).m_as( + unit.kilojoule_per_mole + ) + stddev = unit.Quantity(atomic_energies_stats["stddev"]).m_as( + unit.kilojoule_per_mole + ) + return mean, stddev + def _initialize_postprocessing( self, - processing_operation: List[Dict[str, str]], - readout_operation: List[Dict[str, str]], - dataset_statistic, + postprocessing_parameter: Dict[str, Dict[str, bool]], ): from .processing import ( FromAtomToMoleculeReduction, @@ -725,63 +752,49 @@ def _initialize_postprocessing( CalculateAtomicSelfEnergy, ) - # initialize per atom processing - work_to_be_done_per_property = [] - props = [] - for proc in processing_operation: - if proc["step"] == "normalization": - if dataset_statistic is None: - log.warning( - f"No mean and stddev provided for property {proc['in']}. Setting to default values!" 
- ) - mean = 0.0 - stddev = 1.0 - else: - atomic_energies_stats = dataset_statistic["atomic_energies_stats"] - mean = unit.Quantity(atomic_energies_stats[proc["mean"]]).m_as( - unit.kilojoule_per_mole - ) - stddev = unit.Quantity(atomic_energies_stats[proc["stddev"]]).m_as( - unit.kilojoule_per_mole + # register properties + for property in postprocessing_parameter: + if property.lower() in self._SUPPORTED_PROPERTIES: + self._registered_properties.append(property.lower()) + else: + raise ValueError( + f"Property {property} is not supported. Supported properties are {self._SUPPORTED_PROPERTIES}" + ) + + # register operations + for property, operations in postprocessing_parameter.items(): + postprocessing_sequence = torch.nn.Sequential() + + for operation in operations: + if operation.lower() == "normalize": + mean, stddev = self._get_mean_and_stddev_of_dataset() + postprocessing_sequence.append( + ScaleValues(mean=mean, stddev=stddev) ) - operation = ScaleValues(mean=mean, stddev=stddev) - work_to_be_done_per_property.append(operation) - props.append(proc) - - elif proc["step"] == "calculate_ase": - if dataset_statistic is None: - raise RuntimeError( - "No dataset statistics provided for ASE calculation. Skipping!" + elif operation.lower() == "from_atom_to_molecule_reduction": + postprocessing_sequence.append(FromAtomToMoleculeReduction()) + elif operation.lower() == "calculate_atomic_self_energy": + atomic_self_energies = self.dataset_statistic[ + "atomic_self_energies" + ] + postprocessing_sequence.append( + CalculateAtomicSelfEnergy(atomic_self_energies)() ) + else: - atomic_self_energies = dataset_statistic["atomic_self_energies"] - operation = CalculateAtomicSelfEnergy(atomic_self_energies) - work_to_be_done_per_property.append(operation) - - props.append(proc) - - self.per_atom_operations = ModuleList(work_to_be_done_per_property) - self.per_atom_operations_prop = props - - # initialize per molecule reduction - work_to_be_done_per_property = [] - props = [] - for proc in readout_operation: - if proc["step"] == "from_atom_to_molecule": - operation = FromAtomToMoleculeReduction( - reduction_mode=proc["mode"], - ) - work_to_be_done_per_property.append(operation) - props.append(proc) + raise ValueError( + f"Operation {operation} is not implemented. Supported properties are {self._SUPPORTED_OPERATIONS}" + ) - self.readout_operation = ModuleList(work_to_be_done_per_property) - self.readout_prop = props + self.registered_chained_operations[property] = postprocessing_sequence def forward(self, outputs: Dict[str, torch.Tensor]): """ Perform post-processing operations on per-atom properties and reduction operations to calculate per-molecule properties. """ - for property, processing in zip(self.per_atom_operations_prop, self.per_atom_operations): + for property, processing in zip( + self.per_atom_operations_prop, self.per_atom_operations + ): inputs = [outputs[in_key] for in_key in property["in"]] outputs[property["out"]] = processing(*inputs) @@ -800,11 +813,7 @@ def forward(self, outputs: Dict[str, torch.Tensor]): class BaseNetwork(Module): def __init__( - self, - *, - processing_operation: Dict[str, str], - readout_operation: Dict[str, str], - dataset_statistic: Optional[Dict[str, float]] = None, + self, *, postprocessing_parameter: Dict[str, Dict[str, bool]], dataset_statistic ): """ The BaseNetwork wraps the input preparation (including pairlist calculation, d_ij and r_ij calculation), the actual model as well as the output preparation in a wrapper class. 
@@ -813,19 +822,12 @@ def __init__( Parameters ---------- - processing_operation : List[Dict[str, str]] - A list of dictionaries containing the processing steps to be applied to the model output. - readout_operation : List[Dict[str, str]] - A list of dictionaries containing the readout_operation steps to be applied to the model output. - dataset_statistic : Dict[str, float], optional - A dictionary containing the dataset statistics for the model, by default None. + postprocessing_parameter : Dict[str, Dict[str, bool]] # TODO: update """ super().__init__() self.postprocessing = PostProcessing( - processing_operation=processing_operation, - dataset_statistic=dataset_statistic, - readout_operation=readout_operation, + postprocessing_parameter, dataset_statistic ) def load_state_dict( diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index f1f13927..99293c1f 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -199,10 +199,10 @@ def __init__(self, atomic_self_energies) -> None: # if values in atomic_self_energies are strings convert them to kJ/mol if isinstance(list(atomic_self_energies.values())[0], str): atomic_self_energies = { - key: unit.Quantity(value) - for key, value in atomic_self_energies.items() + key: unit.Quantity(value) for key, value in atomic_self_energies.items() } self.atomic_self_energies = AtomicSelfEnergies(atomic_self_energies) + self.reduction = FromAtomToMoleculeReduction(reduction_mode="sum") def forward( self, @@ -235,4 +235,7 @@ def forward( # contains the atomic self energy for each atomic number ase_tensor = ase_tensor_for_indexing[atomic_numbers] - return ase_tensor + # then we need to sum over atoms to get the molecular self energy + per_molecule_self_energy = self.reduction(ase_tensor, atomic_subsystem_indices) + + return per_molecule_self_energy diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index ad70d384..b7f50a53 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -376,8 +376,7 @@ def __init__( cutoff: unit.Quantity, number_of_filters: int, shared_interactions: bool, - processing_operation: List[Dict[str, str]], - readout_operation: List[Dict[str, str]], + postprocessing_parameter: Dict[str, Dict[str, bool]], dataset_statistic: Optional[Dict[str, float]] = None, ) -> None: """ @@ -400,8 +399,7 @@ def __init__( """ super().__init__( dataset_statistic=dataset_statistic, - processing_operation=processing_operation, - readout_operation=readout_operation, + postprocessing_parameter=postprocessing_parameter, ) from modelforge.utils.units import _convert diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 4136462d..59799cba 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -370,19 +370,16 @@ def test_forward_pass( nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") + config = load_configs(f"{model_name.lower()}", "qm9") # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) - # Setup loss - from modelforge.train.training import return_toml_config + model_parameter = config["potential"] # test the forward pass through each of the models model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment=simulation_environment, - model_parameter=potential_parameter, + 
model_parameter=model_parameter, ) if "JAX" in str(type(model)): nnp_input = nnp_input.as_jax_namedtuple() diff --git a/modelforge/train/training.py b/modelforge/train/training.py index 2155aa24..95726b24 100644 --- a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -281,8 +281,6 @@ def __init__( "NNP name must be specified in nnp_parameters with key 'model_name'." ) nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_name) - if nnp_class is None: - raise ValueError(f"Specified NNP name '{model_name}' is not implemented.") self.model = nnp_class( **model_parameter_, From 9632a7d81f10ed46a9d2674a281ba2ab96a33e85 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 10 Jul 2024 16:26:43 +0200 Subject: [PATCH 38/78] updated toml --- modelforge/tests/data/potential_defaults/schnet.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 42cf72e3..59c28487 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -1,8 +1,7 @@ [potential] model_name = "SchNet" -[potential.potential_parameter] - +[potential.core_parameter] max_Z = 101 number_of_atom_features = 32 number_of_radial_basis_functions = 20 From dc63f1319f044f01a107b1de165a0e026acf2932 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 10 Jul 2024 17:25:23 +0200 Subject: [PATCH 39/78] update names, initial postprocessing implementation --- modelforge/potential/ani.py | 2 +- modelforge/potential/models.py | 75 +++++++++++-------- modelforge/potential/painn.py | 2 +- modelforge/potential/physnet.py | 6 +- modelforge/potential/processing.py | 37 +++++---- modelforge/potential/sake.py | 5 +- modelforge/potential/schnet.py | 2 +- .../tests/data/potential_defaults/ani2x.toml | 4 +- .../tests/data/potential_defaults/painn.toml | 4 +- .../tests/data/potential_defaults/sake.toml | 4 +- .../tests/data/potential_defaults/schnet.toml | 3 +- modelforge/tests/test_models.py | 20 +++-- modelforge/tests/test_painn.py | 7 +- modelforge/tests/test_sake.py | 2 +- modelforge/tests/test_schnet.py | 2 +- modelforge/tests/test_spk.py | 4 +- 16 files changed, 102 insertions(+), 77 deletions(-) diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index ba9cb3c8..ac1262bf 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -540,7 +540,7 @@ def compute_properties(self, data: AniNeuralNetworkData) -> Dict[str, torch.Tens E_i = self.interaction_modules(representation) return { - "E_i": E_i, + "per_atom_energy": E_i, "atomic_subsystem_indices": data.atomic_subsystem_indices, } diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 133a3501..a39e64e1 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -562,12 +562,8 @@ log.debug(f"{training_parameter=}") log.debug(f"{model_parameter=}") - # get model - model_type = model_parameter["model_name"] - nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_type) - - # add modifications to NNP if requested + # obtain model for training if use == "training": if simulation_environment == "JAX": log.warning( @@ -580,11 +576,10 @@ dataset_statistic=dataset_statistic, ) return model + # obtain model for inference elif use == "inference": - # if this model_parameter dictionary ahs already been used - # for training the `model_name` might have been set - if
"model_name" in model_parameter: - del model_parameter["model_name"] + model_type = model_parameter["model_name"] + nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_type) model = nnp_class( **model_parameter["core_parameter"], postprocessing_parameter=model_parameter["postprocessing"], @@ -694,7 +689,7 @@ def _input_checks(self, data: Union[NNPInput, NamedTuple]): class PostProcessing(torch.nn.Module): - _SUPPORTED_PROPERTIES = ["energy"] + _SUPPORTED_PROPERTIES = ["per_atom_energy"] _SUPPORTED_OPERATIONS = ["normalize", "from_atom_to_molecule_reduction"] def __init__( @@ -712,7 +707,7 @@ def __init__( super().__init__() self._registered_properties: List[str] = [] - + # operations that use nn.Sequence to pass the output of the model to the next self.registered_chained_operations = ModuleDict() # operations that don't requre any nn.Sequence @@ -725,7 +720,7 @@ def __init__( ) def _get_mean_and_stddev_of_dataset(self) -> Tuple[float, float]: - + if self.dataset_statistic is None: mean = 0.0 stddev = 1.0 @@ -764,15 +759,27 @@ def _initialize_postprocessing( # register operations for property, operations in postprocessing_parameter.items(): postprocessing_sequence = torch.nn.Sequential() + prostprocessing_sequence_names = [] for operation in operations: - if operation.lower() == "normalize": + if operation.lower() == "normalize" and property == "per_atom_energy": mean, stddev = self._get_mean_and_stddev_of_dataset() postprocessing_sequence.append( ScaleValues(mean=mean, stddev=stddev) ) - elif operation.lower() == "from_atom_to_molecule_reduction": - postprocessing_sequence.append(FromAtomToMoleculeReduction()) + prostprocessing_sequence_names.append(operation) + # check if also reduction is requested + for operation in operations: + if operation.lower() == "from_atom_to_molecule_reduction": + postprocessing_sequence.append( + FromAtomToMoleculeReduction( + per_atom_property_name="per_atom_energy", + index_name="atomic_subsystem_indices", + output_name="per_molecule_energy", + ) + ) + prostprocessing_sequence_names.append(operation) + elif operation.lower() == "calculate_atomic_self_energy": atomic_self_energies = self.dataset_statistic[ "atomic_self_energies" @@ -780,32 +787,36 @@ def _initialize_postprocessing( postprocessing_sequence.append( CalculateAtomicSelfEnergy(atomic_self_energies)() ) + prostprocessing_sequence_names.append(operation) - else: - raise ValueError( - f"Operation {operation} is not implemented. Supported properties are {self._SUPPORTED_OPERATIONS}" + postprocessing_sequence.append( + FromAtomToMoleculeReduction( + per_atom_property_name="ase_tensor", + index_name="atomic_subsystem_indices", + output_name="per_molecule_self_energy", + ) ) + elif ( + operation.lower() == "from_atom_to_molecule_reduction" + and operation.lower() not in prostprocessing_sequence_names + ): + postprocessing_sequence.append(FromAtomToMoleculeReduction()) + prostprocessing_sequence_names.append(operation) + + log.debug(prostprocessing_sequence_names) + self.registered_chained_operations[property] = postprocessing_sequence def forward(self, outputs: Dict[str, torch.Tensor]): """ - Perform post-processing operations on per-atom properties and reduction operations to calculate per-molecule properties. + Perform post-processing operations for all registered properties. 
""" - for property, processing in zip( - self.per_atom_operations_prop, self.per_atom_operations - ): - inputs = [outputs[in_key] for in_key in property["in"]] - outputs[property["out"]] = processing(*inputs) - - # if per atom properties need to be combined - # TODO: Not Implemented yet! - # perform readout_operation on properties - for property, processing in zip(self.readout_prop, self.readout_operation): - outputs[property["out"]] = processing( - outputs[property["in"]], outputs[property["index_key"]] - ) + a = 7 + for key, value in outputs.items(): + if key in self._registered_properties: + outputs[key] = self.registered_chained_operations[key](outputs[key]) return outputs diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index ecc49dec..705b5793 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -205,7 +205,7 @@ def compute_properties( E_i = self.energy_layer(q).squeeze(1) return { - "E_i": E_i, + per_atom_energy: E_i, "mu": mu, "q": q, "atomic_subsystem_indices": data.atomic_subsystem_indices, diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index acf09117..cc97e9fa 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -494,7 +494,9 @@ def _model_specific_input_preparation( return nnp_input - def compute_properties(self, data: PhysNetNeuralNetworkData) -> Dict[str, torch.Tensor]: + def compute_properties( + self, data: PhysNetNeuralNetworkData + ) -> Dict[str, torch.Tensor]: """ Calculate the energy for a given input batch. Parameters @@ -566,7 +568,7 @@ def compute_properties(self, data: PhysNetNeuralNetworkData) -> Dict[str, torch. q_i = prediction_i_shifted_scaled[:, 1] # shape(nr_of_atoms, 1) output = { - "E_i": E_i.contiguous(), # reshape memory mapping for JAX/dlpack + per_atom_energy: E_i.contiguous(), # reshape memory mapping for JAX/dlpack "q_i": q_i.contiguous(), "atomic_subsystem_indices": data.atomic_subsystem_indices, "atomic_numbers": data.atomic_numbers, diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 99293c1f..e9dacb8d 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -34,7 +34,11 @@ class FromAtomToMoleculeReduction(torch.nn.Module): def __init__( self, + per_atom_property_name: str, + index_name: str, + output_name: str, reduction_mode: str = "sum", + keep_per_atom_property: bool = False, ): """ Initializes the per-atom property readout_operation module. @@ -42,10 +46,11 @@ def __init__( """ super().__init__() self.reduction_mode = reduction_mode + self.per_atom_property_name = per_atom_property_name + self.output_name = output_name + self.index_name = index_name - def forward( - self, per_atom_property: torch.Tensor, index: torch.Tensor - ) -> torch.Tensor: + def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: """ Parameters @@ -57,7 +62,8 @@ def forward( ------- Tensor, shape [nr_of_moleculs, 1], the per-molecule property. 
""" - indices = index.to(torch.int64) + indices = data[self.index_name].to(torch.int64) + per_atom_property = data[self.per_atom_property_name] # Perform scatter add operation for atoms belonging to the same molecule property_per_molecule_zeros = torch.zeros( len(indices.unique()), @@ -68,7 +74,11 @@ def forward( property_per_molecule = property_per_molecule_zeros.scatter_reduce( 0, indices, per_atom_property, reduce=self.reduction_mode ) - return property_per_molecule + data[self.output_name] = property_per_molecule + if self.keep_per_atom_property is False: + del data[self.per_atom_property_name] + + return data from dataclasses import dataclass, field @@ -202,13 +212,8 @@ def __init__(self, atomic_self_energies) -> None: key: unit.Quantity(value) for key, value in atomic_self_energies.items() } self.atomic_self_energies = AtomicSelfEnergies(atomic_self_energies) - self.reduction = FromAtomToMoleculeReduction(reduction_mode="sum") - def forward( - self, - atomic_numbers: torch.Tensor, - atomic_subsystem_indices: torch.Tensor, - ) -> torch.Tensor: + def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: """ Calculates the molecular self energy. @@ -221,6 +226,8 @@ def forward( torch.Tensor The tensor containing the molecular self energy for each molecule. """ + atomic_numbers = data["atomic_numbers"] + atomic_subsystem_indices = data["atomic_subsystem_indices"] atomic_subsystem_indices = atomic_subsystem_indices.to( dtype=torch.long, device=atomic_numbers.device @@ -231,11 +238,9 @@ def forward( device=atomic_numbers.device ) - # first, we need to use the atomic numbers to generate a tensor that + # use the atomic numbers to generate a tensor that # contains the atomic self energy for each atomic number ase_tensor = ase_tensor_for_indexing[atomic_numbers] - # then we need to sum over atoms to get the molecular self energy - per_molecule_self_energy = self.reduction(ase_tensor, atomic_subsystem_indices) - - return per_molecule_self_energy + data["ase_tensor"] = ase_tensor + return data diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index e882f137..0b0246b7 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -171,7 +171,10 @@ def compute_properties(self, data: SAKENeuralNetworkInput): # Use squeeze to remove dimensions of size 1 E_i = self.energy_layer(h).squeeze(1) - return {"E_i": E_i, "atomic_subsystem_indices": data.atomic_subsystem_indices} + return { + per_atom_energy: E_i, + "atomic_subsystem_indices": data.atomic_subsystem_indices, + } class SAKEInteraction(nn.Module): diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index b7f50a53..fbd99763 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -206,7 +206,7 @@ def compute_properties( E_i = self.energy_layer(x).squeeze(1) return { - "E_i": E_i, + per_atom_energy: E_i, "scalar_representation": x, "atomic_subsystem_indices": data.atomic_subsystem_indices, } diff --git a/modelforge/tests/data/potential_defaults/ani2x.toml b/modelforge/tests/data/potential_defaults/ani2x.toml index 2476f6a9..82409313 100644 --- a/modelforge/tests/data/potential_defaults/ani2x.toml +++ b/modelforge/tests/data/potential_defaults/ani2x.toml @@ -12,8 +12,8 @@ angular_dist_divisions = 8 processing_operation = [ { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, + per_atom_energy, + ], out = per_atom_energy, step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, { in = [ 
"atomic_numbers", 'atomic_subsystem_indices', diff --git a/modelforge/tests/data/potential_defaults/painn.toml b/modelforge/tests/data/potential_defaults/painn.toml index 4d99b88e..d0c90fae 100644 --- a/modelforge/tests/data/potential_defaults/painn.toml +++ b/modelforge/tests/data/potential_defaults/painn.toml @@ -13,8 +13,8 @@ shared_filters = false processing_operation = [ { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, + per_atom_energy, + ], out = per_atom_energy, step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, { in = [ "atomic_numbers", 'atomic_subsystem_indices', diff --git a/modelforge/tests/data/potential_defaults/sake.toml b/modelforge/tests/data/potential_defaults/sake.toml index ade3f642..8b16d067 100644 --- a/modelforge/tests/data/potential_defaults/sake.toml +++ b/modelforge/tests/data/potential_defaults/sake.toml @@ -12,8 +12,8 @@ number_of_spatial_attention_heads = 4 processing_operation = [ { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, + per_atom_energy, + ], out = per_atom_energy, step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, { in = [ "atomic_numbers", 'atomic_subsystem_indices', diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 59c28487..68cb3173 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -11,7 +11,6 @@ number_of_filters = 32 shared_interactions = false [potential.postprocessing] -[potential.postprocessing.energy] +[potential.postprocessing.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true - diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 59799cba..297eff71 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -344,14 +344,14 @@ def test_forward_pass_with_all_datasets(model_name, dataset_name, datamodule_fac # test that the output has the following keys assert "E" in output - assert "E_i" in output + assert per_atom_energy in output assert "mse" in output assert "ase" in output assert output["E"].shape[0] == 64 assert output["mse"].shape[0] == 64 assert output["ase"].shape == batch.nnp_input.atomic_numbers.shape - assert output["E_i"].shape == batch.nnp_input.atomic_numbers.shape + assert output[per_atom_energy].shape == batch.nnp_input.atomic_numbers.shape pair_list = batch.nnp_input.pair_list # pairlist is in ascending order in row 0 @@ -396,12 +396,20 @@ def test_forward_pass( if "JAX" not in str(type(model)): # assert that the following tensor has equal values for dim=0 index 1 to 4 and 6 to 8 - assert torch.allclose(output["E_i"][1:4], output["E_i"][1], atol=1e-5) - assert torch.allclose(output["E_i"][6:8], output["E_i"][6], atol=1e-5) + assert torch.allclose( + output[per_atom_energy][1:4], output[per_atom_energy][1], atol=1e-5 + ) + assert torch.allclose( + output[per_atom_energy][6:8], output[per_atom_energy][6], atol=1e-5 + ) # make sure that the total energy is \sum E_i - assert torch.allclose(output["E"][0], output["E_i"][0:5].sum(dim=0), atol=1e-5) - assert torch.allclose(output["E"][1], output["E_i"][5:9].sum(dim=0), atol=1e-5) + assert torch.allclose( + output["E"][0], output[per_atom_energy][0:5].sum(dim=0), atol=1e-5 + ) + assert torch.allclose( + output["E"][1], output[per_atom_energy][5:9].sum(dim=0), atol=1e-5 + ) @pytest.mark.parametrize("model_name", 
_Implemented_NNPs.get_all_neural_network_names()) diff --git a/modelforge/tests/test_painn.py b/modelforge/tests/test_painn.py index aa5cbe31..3800600d 100644 --- a/modelforge/tests/test_painn.py +++ b/modelforge/tests/test_painn.py @@ -153,7 +153,6 @@ def test_equivariance(single_batch_with_batchsize_64): from modelforge.tests.test_schnet import setup_single_methane_input - def setup_representation( cutoff, nr_atom_basis, number_of_gaussians, nr_of_interactions ): @@ -176,7 +175,7 @@ def setup_representation( { "step": "from_atom_to_molecule", "mode": "sum", - "in": "E_i", + "in": per_atom_energy, "index_key": "atomic_subsystem_indices", "out": "E", } @@ -220,9 +219,7 @@ def test_compare_representation(): torch.manual_seed(1234) model.core_module.representation_module.filter_net.reset_parameters() - calculated_results = model.core_module.forward( - prepared_input, pairlist_output - ) + calculated_results = model.core_module.forward(prepared_input, pairlist_output) reference_results = load_precalculated_painn_results() # check that the scalar and vector representations are the same diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 8e896ec2..495ccb6f 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -431,7 +431,7 @@ def test_sake_model_against_reference(single_batch_with_batchsize_1): { "step": "from_atom_to_molecule", "mode": "sum", - "in": "E_i", + "in": per_atom_energy, "index_key": "atomic_subsystem_indices", "out": "E", } diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index 8e847525..cf04d22e 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -31,7 +31,7 @@ def initialize_model( { "step": "from_atom_to_molecule", "mode": "sum", - "in": "E_i", + "in": per_atom_energy, "index_key": "atomic_subsystem_indices", "out": "E", } diff --git a/modelforge/tests/test_spk.py b/modelforge/tests/test_spk.py index 04204272..d4f6fedf 100644 --- a/modelforge/tests/test_spk.py +++ b/modelforge/tests/test_spk.py @@ -83,7 +83,7 @@ def setup_modelforge_painn_representation( { "step": "from_atom_to_molecule", "mode": "sum", - "in": "E_i", + "in": per_atom_energy, "index_key": "atomic_subsystem_indices", "out": "E", } @@ -463,7 +463,7 @@ def setup_mf_schnet_representation( { "step": "from_atom_to_molecule", "mode": "sum", - "in": "E_i", + "in": per_atom_energy, "index_key": "atomic_subsystem_indices", "out": "E", } From 645eb353b4af934d28b42dcfc38f46a1f6dcede9 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 10 Jul 2024 17:56:31 +0200 Subject: [PATCH 40/78] working prototype --- modelforge/potential/models.py | 21 ++++++++++++++------- modelforge/potential/processing.py | 14 ++++++++++---- modelforge/potential/schnet.py | 2 +- modelforge/tests/test_models.py | 5 ++++- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index a39e64e1..1b1ee072 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -765,7 +765,12 @@ def _initialize_postprocessing( if operation.lower() == "normalize" and property == "per_atom_energy": mean, stddev = self._get_mean_and_stddev_of_dataset() postprocessing_sequence.append( - ScaleValues(mean=mean, stddev=stddev) + ScaleValues( + mean=mean, + stddev=stddev, + property="per_atom_energy", + output_name="per_atom_energy", + ) ) prostprocessing_sequence_names.append(operation) # check if also reduction is requested @@ -808,17 +813,19 @@ def 
_initialize_postprocessing( self.registered_chained_operations[property] = postprocessing_sequence - def forward(self, outputs: Dict[str, torch.Tensor]): + def forward(self, data: Dict[str, torch.Tensor]): """ Perform post-processing operations for all registered properties. """ - a = 7 - for key, value in outputs.items(): - if key in self._registered_properties: - outputs[key] = self.registered_chained_operations[key](outputs[key]) - return outputs + # NOTE: this is not very elegant, but I am unsure how to do this better + # I am currently directly writing new keys and values in the data dictionary + for property in list(data.keys()): + if property in self._registered_properties: + self.registered_chained_operations[property](data) + + return data class BaseNetwork(Module): diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index e9dacb8d..4e91216d 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -49,8 +49,9 @@ def __init__( self.per_atom_property_name = per_atom_property_name self.output_name = output_name self.index_name = index_name + self.keep_per_atom_property = keep_per_atom_property - def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Parameters @@ -177,13 +178,17 @@ def ase_tensor_for_indexing(self) -> torch.Tensor: class ScaleValues(torch.nn.Module): - def __init__(self, mean: float, stddev: float) -> None: + def __init__( + self, mean: float, stddev: float, property: str, output_name: str + ) -> None: super().__init__() self.register_buffer("mean", torch.tensor([mean])) self.register_buffer("stddev", torch.tensor([stddev])) + self.property = property + self.output_name = output_name - def forward(self, values_to_be_scaled: torch.Tensor) -> torch.Tensor: + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Rescales values using the provided mean and stddev. @@ -198,7 +203,8 @@ def forward(self, values_to_be_scaled: torch.Tensor) -> torch.Tensor: The rescaled values. """ - return values_to_be_scaled * self.stddev + self.mean + data[self.output_name] = data[self.property] * self.stddev + self.mean + return data class CalculateAtomicSelfEnergy(torch.nn.Module): diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index fbd99763..922239d4 100644 --- a/modelforge/potential/schnet.py +++ b/modelforge/potential/schnet.py @@ -206,7 +206,7 @@ def compute_properties( E_i = self.energy_layer(x).squeeze(1) return { - per_atom_energy: E_i, + 'per_atom_energy': E_i, "scalar_representation": x, "atomic_subsystem_indices": data.atomic_subsystem_indices, } diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 297eff71..679b6b72 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -387,7 +387,10 @@ def test_forward_pass( output = model(nnp_input) # test tat we get an energie per molecule - assert len(output["E"]) == nr_of_mols + assert len(output["per_molecule_energy"]) == nr_of_mols + + + # TEST WORKS UNTIL HERE # the batch consists of methane (CH4) and amamonium (NH3) # which has symmetric hydrogens. 
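The key contract introduced by the two patches above is easy to lose in the hunks: every post-processing module now takes the full output dictionary, writes its result under a configurable key, and returns that same dictionary, which is what lets `PostProcessing` chain the modules with a plain `torch.nn.Sequential`. The following minimal sketch (not part of the patch series) illustrates the dict-in/dict-out pattern in isolation; the module and key names mirror the diff, but the constructors are trimmed to the essentials (the real `FromAtomToMoleculeReduction` also takes `reduction_mode` and `keep_per_atom_property`), and the example tensors and statistics are invented for illustration:

```python
import torch


class ScaleValues(torch.nn.Module):
    """Undo dataset normalization: E_i -> E_i * stddev + mean."""

    def __init__(self, mean: float, stddev: float, property: str, output_name: str):
        super().__init__()
        self.register_buffer("mean", torch.tensor([mean]))
        self.register_buffer("stddev", torch.tensor([stddev]))
        self.property = property
        self.output_name = output_name

    def forward(self, data: dict) -> dict:
        # write the rescaled property back into the shared dictionary
        data[self.output_name] = data[self.property] * self.stddev + self.mean
        return data  # returning the dict is what makes nn.Sequential chaining work


class FromAtomToMoleculeReduction(torch.nn.Module):
    """Scatter-sum a per-atom property into a per-molecule property."""

    def __init__(self, per_atom_property_name: str, index_name: str, output_name: str):
        super().__init__()
        self.per_atom_property_name = per_atom_property_name
        self.index_name = index_name
        self.output_name = output_name

    def forward(self, data: dict) -> dict:
        indices = data[self.index_name].to(torch.int64)
        per_atom = data[self.per_atom_property_name]
        # one accumulator slot per molecule, then scatter-add the atomic contributions
        zeros = torch.zeros(int(indices.max()) + 1, dtype=per_atom.dtype)
        data[self.output_name] = zeros.scatter_reduce(0, indices, per_atom, reduce="sum")
        return data


# toy batch: atoms 0-4 belong to molecule 0, atoms 5-8 to molecule 1
data = {
    "per_atom_energy": torch.randn(9),
    "atomic_subsystem_indices": torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1]),
}

# mean/stddev are placeholder numbers standing in for the dataset statistics
pipeline = torch.nn.Sequential(
    ScaleValues(
        mean=-15.2, stddev=3.1, property="per_atom_energy", output_name="per_atom_energy"
    ),
    FromAtomToMoleculeReduction(
        per_atom_property_name="per_atom_energy",
        index_name="atomic_subsystem_indices",
        output_name="per_molecule_energy",
    ),
)

data = pipeline(data)
print(data["per_molecule_energy"])  # shape [2], one energy per molecule
```

Because each module hands the dictionary straight through, the rewritten `forward` can simply call `self.registered_chained_operations[property](data)` and an arbitrary mix of scaling, reduction, and self-energy steps composes without bespoke glue code.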
From 3160923c1d7cf7962c8f0b994e8ade19a4c2b63d Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 11:07:45 +0200 Subject: [PATCH 41/78] update all toml files and all potential output names --- modelforge/potential/ani.py | 2 +- modelforge/potential/models.py | 74 ++++++++++--------- modelforge/potential/painn.py | 2 +- modelforge/potential/physnet.py | 2 +- modelforge/potential/sake.py | 2 +- .../tests/data/potential_defaults/ani2x.toml | 20 ++--- .../tests/data/potential_defaults/painn.toml | 20 ++--- .../tests/data/potential_defaults/sake.toml | 20 ++--- .../tests/data/potential_defaults/schnet.toml | 3 + 9 files changed, 61 insertions(+), 84 deletions(-) diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index ac1262bf..d9084395 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -540,7 +540,7 @@ def compute_properties(self, data: AniNeuralNetworkData) -> Dict[str, torch.Tens E_i = self.interaction_modules(representation) return { - per_atom_energy: E_i, + 'per_atom_energy': E_i, "atomic_subsystem_indices": data.atomic_subsystem_indices, } diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 1b1ee072..2fa272b5 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -710,8 +710,6 @@ def __init__( # operations that use nn.Sequence to pass the output of the model to the next self.registered_chained_operations = ModuleDict() - # operations that don't requre any nn.Sequence - self.registered_independent_operations = ModuleDict() self.dataset_statistic = dataset_statistic @@ -761,8 +759,9 @@ def _initialize_postprocessing( postprocessing_sequence = torch.nn.Sequential() prostprocessing_sequence_names = [] - for operation in operations: - if operation.lower() == "normalize" and property == "per_atom_energy": + # for each property parse the requested operations + if property == "per_atom_energy": + if operations.get("normalize", False): mean, stddev = self._get_mean_and_stddev_of_dataset() postprocessing_sequence.append( ScaleValues( @@ -772,42 +771,48 @@ def _initialize_postprocessing( output_name="per_atom_energy", ) ) - prostprocessing_sequence_names.append(operation) - # check if also reduction is requested - for operation in operations: - if operation.lower() == "from_atom_to_molecule_reduction": - postprocessing_sequence.append( - FromAtomToMoleculeReduction( - per_atom_property_name="per_atom_energy", - index_name="atomic_subsystem_indices", - output_name="per_molecule_energy", - ) - ) - prostprocessing_sequence_names.append(operation) - - elif operation.lower() == "calculate_atomic_self_energy": - atomic_self_energies = self.dataset_statistic[ - "atomic_self_energies" - ] - postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies)() - ) - prostprocessing_sequence_names.append(operation) - + prostprocessing_sequence_names.append("normalize") + # check if also reduction is requested + if operations.get("from_atom_to_molecule_reduction", False): postprocessing_sequence.append( FromAtomToMoleculeReduction( - per_atom_property_name="ase_tensor", + per_atom_property_name="per_atom_energy", index_name="atomic_subsystem_indices", - output_name="per_molecule_self_energy", + output_name="per_molecule_energy", + keep_per_atom_property=operations.get( + "keep_per_atom_property", False + ), ) ) + prostprocessing_sequence_names.append( + "from_atom_to_molecule_reduction" + ) + + # check if also self-energies are requested + if 
operations.get("calculate_molecular_self_energy", False): + atomic_self_energies = self.dataset_statistic["atomic_self_energies"] + + postprocessing_sequence.append( + CalculateAtomicSelfEnergy(atomic_self_energies)() + ) + prostprocessing_sequence_names.append("calculate_molecular_self_energy") + + postprocessing_sequence.append( + FromAtomToMoleculeReduction( + per_atom_property_name="ase_tensor", + index_name="atomic_subsystem_indices", + output_name="per_molecule_self_energy", + ) + ) + + # check if also self-energies are requested + if operations.get("calculate_atomic_self_energy", False): + atomic_self_energies = self.dataset_statistic["atomic_self_energies"] - elif ( - operation.lower() == "from_atom_to_molecule_reduction" - and operation.lower() not in prostprocessing_sequence_names - ): - postprocessing_sequence.append(FromAtomToMoleculeReduction()) - prostprocessing_sequence_names.append(operation) + postprocessing_sequence.append( + CalculateAtomicSelfEnergy(atomic_self_energies)() + ) + prostprocessing_sequence_names.append("calculate_atomic_self_energy") log.debug(prostprocessing_sequence_names) @@ -818,7 +823,6 @@ def forward(self, data: Dict[str, torch.Tensor]): Perform post-processing operations for all registered properties. """ - # NOTE: this is not very elegant, but I am unsure how to do this better # I am currently directly writing new keys and values in the data dictionary for property in list(data.keys()): diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index 705b5793..53ab0782 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -205,7 +205,7 @@ def compute_properties( E_i = self.energy_layer(q).squeeze(1) return { - per_atom_energy: E_i, + 'per_atom_energy': E_i, "mu": mu, "q": q, "atomic_subsystem_indices": data.atomic_subsystem_indices, diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index cc97e9fa..d7589c30 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -568,7 +568,7 @@ def compute_properties( q_i = prediction_i_shifted_scaled[:, 1] # shape(nr_of_atoms, 1) output = { - per_atom_energy: E_i.contiguous(), # reshape memory mapping for JAX/dlpack + 'per_atom_energy': E_i.contiguous(), # reshape memory mapping for JAX/dlpack "q_i": q_i.contiguous(), "atomic_subsystem_indices": data.atomic_subsystem_indices, "atomic_numbers": data.atomic_numbers, diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index 0b0246b7..c7b20832 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -172,7 +172,7 @@ def compute_properties(self, data: SAKENeuralNetworkInput): E_i = self.energy_layer(h).squeeze(1) return { - per_atom_energy: E_i, + 'per_atom_energy': E_i, "atomic_subsystem_indices": data.atomic_subsystem_indices, } diff --git a/modelforge/tests/data/potential_defaults/ani2x.toml b/modelforge/tests/data/potential_defaults/ani2x.toml index 82409313..5b1a7070 100644 --- a/modelforge/tests/data/potential_defaults/ani2x.toml +++ b/modelforge/tests/data/potential_defaults/ani2x.toml @@ -10,18 +10,8 @@ angular_max_distance = "3.5 angstrom" angular_min_distance = "0.8 angstrom" angular_dist_divisions = 8 -processing_operation = [ - { in = [ - per_atom_energy, - ], out = per_atom_energy, step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - { in = [ - "atomic_numbers", - 'atomic_subsystem_indices', - ], out = "ase", step = "calculate_ase" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", 
mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, - { step = "from_atom_to_molecule", mode = 'sum', in = 'ase', index_key = 'atomic_subsystem_indices', out = 'mse' }, - -] +[potential.postprocessing] +[potential.postprocessing.per_atom_energy] +normalize = true +from_atom_to_molecule_reduction = true +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/painn.toml b/modelforge/tests/data/potential_defaults/painn.toml index d0c90fae..f0e3623d 100644 --- a/modelforge/tests/data/potential_defaults/painn.toml +++ b/modelforge/tests/data/potential_defaults/painn.toml @@ -11,18 +11,8 @@ number_of_interaction_modules = 3 shared_interactions = false shared_filters = false -processing_operation = [ - { in = [ - per_atom_energy, - ], out = per_atom_energy, step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - { in = [ - "atomic_numbers", - 'atomic_subsystem_indices', - ], out = "ase", step = "calculate_ase" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, - { step = "from_atom_to_molecule", mode = 'sum', in = 'ase', index_key = 'atomic_subsystem_indices', out = 'mse' }, - -] +[potential.postprocessing] +[potential.postprocessing.per_atom_energy] +normalize = true +from_atom_to_molecule_reduction = true +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/sake.toml b/modelforge/tests/data/potential_defaults/sake.toml index 8b16d067..408846ab 100644 --- a/modelforge/tests/data/potential_defaults/sake.toml +++ b/modelforge/tests/data/potential_defaults/sake.toml @@ -10,18 +10,8 @@ cutoff = "5.0 angstrom" number_of_interaction_modules = 6 number_of_spatial_attention_heads = 4 -processing_operation = [ - { in = [ - per_atom_energy, - ], out = per_atom_energy, step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - { in = [ - "atomic_numbers", - 'atomic_subsystem_indices', - ], out = "ase", step = "calculate_ase" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, - { step = "from_atom_to_molecule", mode = 'sum', in = 'ase', index_key = 'atomic_subsystem_indices', out = 'mse' }, - -] +[potential.postprocessing] +[potential.postprocessing.per_atom_energy] +normalize = true +from_atom_to_molecule_reduction = true +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 68cb3173..5c0491d5 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -14,3 +14,6 @@ shared_interactions = false [potential.postprocessing.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true +keep_per_atom_property = true +calculate_molecular_self_energy = true +#calculate_atomic_self_energy = true From 77756f9d2fde03e2278de48c3ebea46631c627a4 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 11:29:24 +0200 Subject: [PATCH 42/78] update input signature and toml files --- modelforge/potential/ani.py | 8 ++- modelforge/potential/models.py | 51 +++++++++++++------ modelforge/potential/painn.py | 8 ++- modelforge/potential/physnet.py | 8 ++- modelforge/potential/sake.py | 8 ++- .../tests/data/potential_defaults/ani2x.toml | 2 +- .../tests/data/potential_defaults/painn.toml | 2 +- .../data/potential_defaults/physnet.toml | 22 +++----- 
.../tests/data/potential_defaults/sake.toml | 2 +- modelforge/tests/test_models.py | 25 ++++----- modelforge/tests/test_sake.py | 2 +- modelforge/tests/test_schnet.py | 2 +- modelforge/tests/test_spk.py | 4 +- 13 files changed, 70 insertions(+), 74 deletions(-) diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index d9084395..b6c14c15 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -540,7 +540,7 @@ def compute_properties(self, data: AniNeuralNetworkData) -> Dict[str, torch.Tens E_i = self.interaction_modules(representation) return { - 'per_atom_energy': E_i, + "per_atom_energy": E_i, "atomic_subsystem_indices": data.atomic_subsystem_indices, } @@ -558,15 +558,13 @@ def __init__( angular_min_distance: Union[unit.Quantity, str], angular_dist_divisions: int, angle_sections: int, - processing_operation: List[Dict[str, str]], - readout_operation: List[Dict[str, str]], + postprocessing_parameter: Dict[str, Dict[str, bool]], dataset_statistic: Optional[Dict[str, float]] = None, ) -> None: super().__init__( - processing_operation=processing_operation, dataset_statistic=dataset_statistic, - readout_operation=readout_operation, + postprocessing_parameter=postprocessing_parameter, ) from modelforge.utils.units import _convert diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 2fa272b5..23464889 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -790,29 +790,48 @@ def _initialize_postprocessing( # check if also self-energies are requested if operations.get("calculate_molecular_self_energy", False): - atomic_self_energies = self.dataset_statistic["atomic_self_energies"] - postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies)() - ) - prostprocessing_sequence_names.append("calculate_molecular_self_energy") + if self.dataset_statistic is None: + log.warning( + "Dataset statistics are required to calculate the molecular self-energies but haven't been provided." + ) + else: + atomic_self_energies = self.dataset_statistic[ + "atomic_self_energies" + ] - postprocessing_sequence.append( - FromAtomToMoleculeReduction( - per_atom_property_name="ase_tensor", - index_name="atomic_subsystem_indices", - output_name="per_molecule_self_energy", + postprocessing_sequence.append( + CalculateAtomicSelfEnergy(atomic_self_energies)() + ) + prostprocessing_sequence_names.append( + "calculate_molecular_self_energy" + ) + + postprocessing_sequence.append( + FromAtomToMoleculeReduction( + per_atom_property_name="ase_tensor", + index_name="atomic_subsystem_indices", + output_name="per_molecule_self_energy", + ) ) - ) # check if also self-energies are requested if operations.get("calculate_atomic_self_energy", False): - atomic_self_energies = self.dataset_statistic["atomic_self_energies"] + if self.dataset_statistic is None: + log.warning( + "Dataset statistics are required to calculate the molecular self-energies but haven't been provided." 
+ ) + else: + atomic_self_energies = self.dataset_statistic[ + "atomic_self_energies" + ] - postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies)() - ) - prostprocessing_sequence_names.append("calculate_atomic_self_energy") + postprocessing_sequence.append( + CalculateAtomicSelfEnergy(atomic_self_energies)() + ) + prostprocessing_sequence_names.append( + "calculate_atomic_self_energy" + ) log.debug(prostprocessing_sequence_names) diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index 53ab0782..e34f6de9 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -205,7 +205,7 @@ def compute_properties( E_i = self.energy_layer(q).squeeze(1) return { - 'per_atom_energy': E_i, + "per_atom_energy": E_i, "mu": mu, "q": q, "atomic_subsystem_indices": data.atomic_subsystem_indices, @@ -484,15 +484,13 @@ def __init__( number_of_interaction_modules: int, shared_interactions: bool, shared_filters: bool, - processing_operation: Dict[str, torch.nn.ModuleList], - readout_operation: Dict[str, List[Dict[str, str]]], + postprocessing_parameter: Dict[str, Dict[str, bool]], dataset_statistic: Optional[Dict[str, float]] = None, epsilon: float = 1e-8, ) -> None: super().__init__( dataset_statistic=dataset_statistic, - processing_operation=processing_operation, - readout_operation=readout_operation, + postprocessing_parameter=postprocessing_parameter, ) from modelforge.utils.units import _convert diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index d7589c30..9c046aac 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -568,7 +568,7 @@ def compute_properties( q_i = prediction_i_shifted_scaled[:, 1] # shape(nr_of_atoms, 1) output = { - 'per_atom_energy': E_i.contiguous(), # reshape memory mapping for JAX/dlpack + "per_atom_energy": E_i.contiguous(), # reshape memory mapping for JAX/dlpack "q_i": q_i.contiguous(), "atomic_subsystem_indices": data.atomic_subsystem_indices, "atomic_numbers": data.atomic_numbers, @@ -590,8 +590,7 @@ def __init__( number_of_radial_basis_functions: int, number_of_interaction_residual: int, number_of_modules: int, - processing_operation: List[Dict[str, str]], - readout_operation: List[Dict[str, str]], + postprocessing_parameter: Dict[str, Dict[str, bool]], dataset_statistic: Optional[Dict[str, float]] = None, ) -> None: """ @@ -602,8 +601,7 @@ def __init__( """ super().__init__( dataset_statistic=dataset_statistic, - processing_operation=processing_operation, - readout_operation=readout_operation, + postprocessing_parameter=postprocessing_parameter, ) from modelforge.utils.units import _convert diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index c7b20832..88d88f8f 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -172,7 +172,7 @@ def compute_properties(self, data: SAKENeuralNetworkInput): E_i = self.energy_layer(h).squeeze(1) return { - 'per_atom_energy': E_i, + "per_atom_energy": E_i, "atomic_subsystem_indices": data.atomic_subsystem_indices, } @@ -560,15 +560,13 @@ def __init__( number_of_spatial_attention_heads: int, number_of_radial_basis_functions: int, cutoff: unit.Quantity, - processing_operation: List[Dict[str, str]], - readout_operation: List[Dict[str, str]], + postprocessing_parameter: Dict[str, Dict[str, bool]], dataset_statistic: Optional[Dict[str, float]] = None, epsilon: float = 1e-8, ): super().__init__( dataset_statistic=dataset_statistic, - 
processing_operation=processing_operation, - readout_operation=readout_operation, + postprocessing_parameter=postprocessing_parameter, ) from modelforge.utils.units import _convert diff --git a/modelforge/tests/data/potential_defaults/ani2x.toml b/modelforge/tests/data/potential_defaults/ani2x.toml index 5b1a7070..d42275df 100644 --- a/modelforge/tests/data/potential_defaults/ani2x.toml +++ b/modelforge/tests/data/potential_defaults/ani2x.toml @@ -1,7 +1,7 @@ [potential] model_name = "ANI2x" -[potential.potential_parameter] +[potential.core_parameter] angle_sections = 4 radial_max_distance = "5.1 angstrom" radial_min_distance = "0.8 angstrom" diff --git a/modelforge/tests/data/potential_defaults/painn.toml b/modelforge/tests/data/potential_defaults/painn.toml index f0e3623d..2eeb83da 100644 --- a/modelforge/tests/data/potential_defaults/painn.toml +++ b/modelforge/tests/data/potential_defaults/painn.toml @@ -1,7 +1,7 @@ [potential] model_name = "PaiNN" -[potential.potential_parameter] +[potential.core_parameter] max_Z = 101 number_of_atom_features = 32 diff --git a/modelforge/tests/data/potential_defaults/physnet.toml b/modelforge/tests/data/potential_defaults/physnet.toml index 5ba2b507..40232839 100644 --- a/modelforge/tests/data/potential_defaults/physnet.toml +++ b/modelforge/tests/data/potential_defaults/physnet.toml @@ -1,7 +1,7 @@ [potential] model_name = "PhysNet" -[potential.potential_parameter] +[potential.core_parameter] max_Z = 101 number_of_atom_features = 64 @@ -10,18 +10,8 @@ cutoff = "5.0 angstrom" number_of_interaction_residual = 3 number_of_modules = 5 -processing_operation = [ - { in = [ - "E_i", - ], out = "E_i", step = "normalization", mean = "E_i_mean", stddev = "E_i_mean" }, - { in = [ - "atomic_numbers", - 'atomic_subsystem_indices', - ], out = "ase", step = "calculate_ase" }, - -] -readout_operation = [ - { step = "from_atom_to_molecule", mode = 'sum', in = 'E_i', index_key = 'atomic_subsystem_indices', out = 'E' }, - { step = "from_atom_to_molecule", mode = 'sum', in = 'ase', index_key = 'atomic_subsystem_indices', out = 'mse' }, - -] +[potential.postprocessing] +[potential.postprocessing.per_atom_energy] +normalize = true +from_atom_to_molecule_reduction = true +keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/sake.toml b/modelforge/tests/data/potential_defaults/sake.toml index 408846ab..d604b640 100644 --- a/modelforge/tests/data/potential_defaults/sake.toml +++ b/modelforge/tests/data/potential_defaults/sake.toml @@ -1,7 +1,7 @@ [potential] model_name = "SAKE" -[potential.potential_parameter] +[potential.core_parameter] max_Z = 101 number_of_atom_features = 64 diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 679b6b72..12a4901c 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -344,14 +344,14 @@ def test_forward_pass_with_all_datasets(model_name, dataset_name, datamodule_fac # test that the output has the following keys assert "E" in output - assert per_atom_energy in output + assert 'per_atom_energy' in output assert "mse" in output assert "ase" in output assert output["E"].shape[0] == 64 assert output["mse"].shape[0] == 64 assert output["ase"].shape == batch.nnp_input.atomic_numbers.shape - assert output[per_atom_energy].shape == batch.nnp_input.atomic_numbers.shape + assert output['per_atom_energy'].shape == batch.nnp_input.atomic_numbers.shape pair_list = batch.nnp_input.pair_list # pairlist is in ascending order in row 0 @@ -367,19 +367,15 @@ def 
test_forward_pass( import torch nnp_input = single_batch_with_batchsize_64.nnp_input - nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] # read default parameters config = load_configs(f"{model_name.lower()}", "qm9") - # Extract parameters - model_parameter = config["potential"] - # test the forward pass through each of the models model = NeuralNetworkPotentialFactory.generate_model( use="inference", simulation_environment=simulation_environment, - model_parameter=model_parameter, + model_parameter=config["potential"], ) if "JAX" in str(type(model)): nnp_input = nnp_input.as_jax_namedtuple() @@ -387,31 +383,30 @@ def test_forward_pass( output = model(nnp_input) # test tat we get an energie per molecule + nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] assert len(output["per_molecule_energy"]) == nr_of_mols - # TEST WORKS UNTIL HERE - # the batch consists of methane (CH4) and amamonium (NH3) - # which has symmetric hydrogens. + # which have chemically equivalent hydrogens at the minimum geometry. # This has to be reflected in the atomic energies E_i, which - # has to be equal for all hydrogens + # have to be equal for all hydrogens if "JAX" not in str(type(model)): # assert that the following tensor has equal values for dim=0 index 1 to 4 and 6 to 8 assert torch.allclose( - output[per_atom_energy][1:4], output[per_atom_energy][1], atol=1e-5 + output['per_atom_energy'][1:4], output['per_atom_energy'][1], atol=1e-5 ) assert torch.allclose( - output[per_atom_energy][6:8], output[per_atom_energy][6], atol=1e-5 + output['per_atom_energy'][6:8], output['per_atom_energy'][6], atol=1e-5 ) # make sure that the total energy is \sum E_i assert torch.allclose( - output["E"][0], output[per_atom_energy][0:5].sum(dim=0), atol=1e-5 + output["per_molecule_energy"][0], output['per_atom_energy'][0:5].sum(dim=0), atol=1e-5 ) assert torch.allclose( - output["E"][1], output[per_atom_energy][5:9].sum(dim=0), atol=1e-5 + output["per_molecule_energy"][1], output['per_atom_energy'][5:9].sum(dim=0), atol=1e-5 ) diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 495ccb6f..ccfe4771 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -431,7 +431,7 @@ def test_sake_model_against_reference(single_batch_with_batchsize_1): { "step": "from_atom_to_molecule", "mode": "sum", - "in": per_atom_energy, + "in": 'per_atom_energy', "index_key": "atomic_subsystem_indices", "out": "E", } diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index cf04d22e..8e29df07 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -31,7 +31,7 @@ def initialize_model( { "step": "from_atom_to_molecule", "mode": "sum", - "in": per_atom_energy, + "in": 'per_atom_energy', "index_key": "atomic_subsystem_indices", "out": "E", } diff --git a/modelforge/tests/test_spk.py b/modelforge/tests/test_spk.py index d4f6fedf..c1da58b2 100644 --- a/modelforge/tests/test_spk.py +++ b/modelforge/tests/test_spk.py @@ -83,7 +83,7 @@ def setup_modelforge_painn_representation( { "step": "from_atom_to_molecule", "mode": "sum", - "in": per_atom_energy, + "in": 'per_atom_energy', "index_key": "atomic_subsystem_indices", "out": "E", } @@ -463,7 +463,7 @@ def setup_mf_schnet_representation( { "step": "from_atom_to_molecule", "mode": "sum", - "in": per_atom_energy, + "in": 'per_atom_energy', "index_key": "atomic_subsystem_indices", "out": "E", } From 7f841a9585fa00f0d57615d43871c7b8849ec250 Mon Sep 17 00:00:00 2001 From: wiederm 
Date: Thu, 11 Jul 2024 11:30:58 +0200 Subject: [PATCH 43/78] remove legacy code --- modelforge/potential/painn.py | 5 ----- modelforge/potential/sake.py | 2 -- 2 files changed, 7 deletions(-) diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index e34f6de9..ccc90c69 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -104,11 +104,6 @@ def __init__( self.embedding_module = Embedding(max_Z, number_of_atom_features) - # initialize the energy readout_operation - from .processing import FromAtomToMoleculeReduction - - self.readout_module = FromAtomToMoleculeReduction() - # initialize representation block self.representation_module = PaiNNRepresentation( cutoff, diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index 88d88f8f..830fa646 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -99,8 +99,6 @@ def __init__( nn.SiLU(), Dense(number_of_atom_features, 1), ) - self.readout_module = FromAtomToMoleculeReduction() - # initialize the interaction networks self.interaction_modules = nn.ModuleList( SAKEInteraction( From 93f0f2866689ba0a3afd153085c783b2feb70512 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 15:31:18 +0200 Subject: [PATCH 44/78] update --- modelforge/potential/models.py | 13 +- .../tests/data/potential_defaults/ani2x.toml | 4 +- .../tests/data/potential_defaults/painn.toml | 4 +- .../data/potential_defaults/physnet.toml | 4 +- .../tests/data/potential_defaults/sake.toml | 4 +- .../tests/data/potential_defaults/schnet.toml | 4 +- .../tests/data/training_defaults/default.toml | 12 +- modelforge/tests/test_models.py | 141 +++++------ modelforge/train/training.py | 237 ++++++------------ 9 files changed, 164 insertions(+), 259 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 23464889..0a675c4f 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -569,7 +569,6 @@ def generate_model( log.warning( "Training in JAX is not availalbe. Falling back to PyTorch." 
) - model_parameter["model_name"] = model_type model = TrainingAdapter( model_parameter=model_parameter, **training_parameter, @@ -582,7 +581,7 @@ def generate_model( nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_type) model = nnp_class( **model_parameter["core_parameter"], - postprocessing_parameter=model_parameter["postprocessing"], + postprocessing_parameter=model_parameter["postprocessing_parameter"], dataset_statistic=dataset_statistic, ) if simulation_environment == "JAX": @@ -590,7 +589,7 @@ def generate_model( else: return model else: - raise ValueError(f"Unsupported 'use' value: {use}") + raise NotImplementedError(f"Unsupported 'use' value: {use}") class InputPreparation(torch.nn.Module): @@ -727,10 +726,10 @@ def _get_mean_and_stddev_of_dataset(self) -> Tuple[float, float]: ) else: atomic_energies_stats = self.dataset_statistic["atomic_energies_stats"] - mean = unit.Quantity(atomic_energies_stats["mean"]).m_as( + mean = unit.Quantity(atomic_energies_stats["E_i_mean"]).m_as( unit.kilojoule_per_mole ) - stddev = unit.Quantity(atomic_energies_stats["stddev"]).m_as( + stddev = unit.Quantity(atomic_energies_stats["E_i_stddev"]).m_as( unit.kilojoule_per_mole ) return mean, stddev @@ -801,7 +800,7 @@ def _initialize_postprocessing( ] postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies)() + CalculateAtomicSelfEnergy(atomic_self_energies) ) prostprocessing_sequence_names.append( "calculate_molecular_self_energy" @@ -894,7 +893,7 @@ def load_state_dict( # Prefix to remove prefix = "model." - excluded_keys = ["loss_module.energy_weight", "loss_module.force_weight"] + excluded_keys = ["loss.per_molecule_energy", "loss.force"] # Create a new dictionary without the prefix in the keys if prefix exists if any(key.startswith(prefix) for key in state_dict.keys()): diff --git a/modelforge/tests/data/potential_defaults/ani2x.toml b/modelforge/tests/data/potential_defaults/ani2x.toml index d42275df..05cae5b9 100644 --- a/modelforge/tests/data/potential_defaults/ani2x.toml +++ b/modelforge/tests/data/potential_defaults/ani2x.toml @@ -10,8 +10,8 @@ angular_max_distance = "3.5 angstrom" angular_min_distance = "0.8 angstrom" angular_dist_divisions = 8 -[potential.postprocessing] -[potential.postprocessing.per_atom_energy] +[potential.postprocessing_parameter] +[potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/painn.toml b/modelforge/tests/data/potential_defaults/painn.toml index 2eeb83da..70292773 100644 --- a/modelforge/tests/data/potential_defaults/painn.toml +++ b/modelforge/tests/data/potential_defaults/painn.toml @@ -11,8 +11,8 @@ number_of_interaction_modules = 3 shared_interactions = false shared_filters = false -[potential.postprocessing] -[potential.postprocessing.per_atom_energy] +[potential.postprocessing_parameter] +[potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/physnet.toml b/modelforge/tests/data/potential_defaults/physnet.toml index 40232839..68b76d91 100644 --- a/modelforge/tests/data/potential_defaults/physnet.toml +++ b/modelforge/tests/data/potential_defaults/physnet.toml @@ -10,8 +10,8 @@ cutoff = "5.0 angstrom" number_of_interaction_residual = 3 number_of_modules = 5 -[potential.postprocessing] -[potential.postprocessing.per_atom_energy] 
+[potential.postprocessing_parameter] +[potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/sake.toml b/modelforge/tests/data/potential_defaults/sake.toml index d604b640..d8fb2cc5 100644 --- a/modelforge/tests/data/potential_defaults/sake.toml +++ b/modelforge/tests/data/potential_defaults/sake.toml @@ -10,8 +10,8 @@ cutoff = "5.0 angstrom" number_of_interaction_modules = 6 number_of_spatial_attention_heads = 4 -[potential.postprocessing] -[potential.postprocessing.per_atom_energy] +[potential.postprocessing_parameter] +[potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 5c0491d5..8626206f 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -10,8 +10,8 @@ number_of_interaction_modules = 3 number_of_filters = 32 shared_interactions = false -[potential.postprocessing] -[potential.postprocessing.per_atom_energy] +[potential.postprocessing_parameter] +[potential.postprocessing_parameter.per_atom_energy] normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml index d6211ac8..664e32fa 100644 --- a/modelforge/tests/data/training_defaults/default.toml +++ b/modelforge/tests/data/training_defaults/default.toml @@ -20,17 +20,17 @@ cooldown = 5 min_lr = 1e-8 threshold = 0.1 threshold_mode = "abs" -monitor = "val/energy/rmse" +monitor = "val/per_molecule_energy/rmse" interval = "epoch" [training.training_parameter.loss_parameter] -loss_property = ['energy', 'force'] -[loss_parameter.weight] -energy = 0.999 -force = 0.001 +loss_property = ['per_molecule_energy', 'force'] +[training.training_parameter.loss_parameter.weight] +per_molecule_energy = 0.999 +force = 0.001 [training.early_stopping] verbose = true -monitor = "val/energy/rmse" +monitor = "val/per_molecule_energy/rmse" min_delta = 0.001 patience = 50 diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index 12a4901c..d5fbaddf 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -120,15 +120,9 @@ def test_energy_scaling_and_offset(): dataset.setup() # -------------------------------# # initialize model - # read default parameters - from modelforge.train.training import return_toml_config - from modelforge.tests.data import potential_defaults - from importlib import resources - + # -------------------------------# config = load_configs("ani2x_without_ase", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) import toml dataset_statistic = toml.load(dataset.dataset_statistic_filename) @@ -205,10 +199,9 @@ def test_dataset_statistic(model_name): from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") + config = load_configs(f"{model_name.lower()}", "qm9") # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) training_parameter = config["training"].get("training_parameter", {}) # test the self energy calculation on the 
QM9 dataset @@ -227,38 +220,56 @@ def test_dataset_statistic(model_name): import toml from openff.units import unit + # load dataset stastics from file dataset_statistic = toml.load(dataset.dataset_statistic_filename) + + # extract value to compare against toml_E_i_mean = unit.Quantity( dataset_statistic["atomic_energies_stats"]["E_i_mean"] ).m # set up training model - model = NeuralNetworkPotentialFactory.generate_model( + training_adapter = NeuralNetworkPotentialFactory.generate_model( use="training", - model_type=model_name, simulation_environment="PyTorch", - model_parameter=potential_parameter, + model_parameter=config["potential"], training_parameter=training_parameter, dataset_statistic=dataset_statistic, ) import torch import numpy as np + print(training_adapter.model.postprocessing.dataset_statistic) # check that the E_i_mean is the same than in the dataset statistics assert np.isclose( - toml_E_i_mean, model.model.postprocessing.per_atom_operations[0].mean + toml_E_i_mean, + unit.Quantity( + training_adapter.model.postprocessing.dataset_statistic[ + "atomic_energies_stats" + ]["E_i_mean"] + ).m, ) - torch.save(model.state_dict(), "model.pth") + torch.save(training_adapter.state_dict(), "model.pth") + # NOTE: we are passing dataset statistics explicit to the constructor + # this is not saved with the state_dict model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment="PyTorch", - model_parameter=potential_parameter, + model_parameter=config["potential"], + dataset_statistic=dataset_statistic, ) model.load_state_dict(torch.load("model.pth")) - assert np.isclose(toml_E_i_mean, model.postprocessing.per_atom_operations[0].mean) + + a = 7 + + assert np.isclose( + toml_E_i_mean, + unit.Quantity( + model.postprocessing.dataset_statistic["atomic_energies_stats"]["E_i_mean"] + ).m, + ) @pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names()) @@ -273,32 +284,28 @@ def test_energy_between_simulation_environments( # test the forward pass through each of the models # cast input and model to torch.float64 # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") + config = load_configs(f"{model_name.lower()}", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) # Setup loss from modelforge.train.training import return_toml_config torch.manual_seed(42) model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment="PyTorch", - model_parameter=potential_parameter, + model_parameter=config["potential"], ) - output_torch = model(nnp_input)["E"] + output_torch = model(nnp_input)["per_molecule_energy"] torch.manual_seed(42) model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment="JAX", - model_parameter=potential_parameter, + model_parameter=config["potential"], ) nnp_input = nnp_input.as_jax_namedtuple() - output_jax = model(nnp_input)["E"] + output_jax = model(nnp_input)["per_molecule_energy"] # test tat we get an energie per molecule assert np.isclose(output_torch.sum().detach().numpy(), output_jax.sum()) @@ -321,13 +328,8 @@ def test_forward_pass_with_all_datasets(model_name, dataset_name, datamodule_fac train_dataloader = dataset.train_dataloader() batch = next(iter(train_dataloader)) - # test that the neighborlist is correctly generated - # cast input and model to torch.float64 - # read default 
    config = load_configs(f"{model_name.lower()}", dataset_name.lower())

-    # Extract parameters
-    potential_parameter = config["potential"].get("potential_parameter", {})
     from modelforge.potential.models import NeuralNetworkPotentialFactory
     import toml
@@ -335,23 +337,18 @@ def test_forward_pass_with_all_datasets(model_name, dataset_name, datamodule_fac

     model = NeuralNetworkPotentialFactory.generate_model(
         use="inference",
-        model_type=model_name,
         simulation_environment="PyTorch",
-        model_parameter=potential_parameter,
+        model_parameter=config["potential"],
         dataset_statistic=dataset_statistic,
     )
     output = model(batch.nnp_input)

-    # test that the output has the following keys
-    assert "E" in output
-    assert 'per_atom_energy' in output
-    assert "mse" in output
-    assert "ase" in output
+    # test that the output has the following keys and the following dims
+    assert "per_molecule_energy" in output
+    assert "per_atom_energy" in output

-    assert output["E"].shape[0] == 64
-    assert output["mse"].shape[0] == 64
-    assert output["ase"].shape == batch.nnp_input.atomic_numbers.shape
-    assert output['per_atom_energy'].shape == batch.nnp_input.atomic_numbers.shape
+    assert output["per_molecule_energy"].shape[0] == 64
+    assert output["per_atom_energy"].shape == batch.nnp_input.atomic_numbers.shape

     pair_list = batch.nnp_input.pair_list
     # pairlist is in ascending order in row 0
@@ -370,6 +367,7 @@ def test_forward_pass(
     # read default parameters
     config = load_configs(f"{model_name.lower()}", "qm9")

+    nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0]

     # test the forward pass through each of the models
     model = NeuralNetworkPotentialFactory.generate_model(
@@ -383,10 +381,8 @@ def test_forward_pass(
     output = model(nnp_input)

     # test that we get an energy per molecule
-    nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0]
     assert len(output["per_molecule_energy"]) == nr_of_mols

-    # the batch consists of methane (CH4) and ammonia (NH3)
     # which have chemically equivalent hydrogens at the minimum geometry.
# This has to be reflected in the atomic energies E_i, which @@ -395,18 +391,22 @@ def test_forward_pass( # assert that the following tensor has equal values for dim=0 index 1 to 4 and 6 to 8 assert torch.allclose( - output['per_atom_energy'][1:4], output['per_atom_energy'][1], atol=1e-5 + output["per_atom_energy"][1:4], output["per_atom_energy"][1], atol=1e-5 ) assert torch.allclose( - output['per_atom_energy'][6:8], output['per_atom_energy'][6], atol=1e-5 + output["per_atom_energy"][6:8], output["per_atom_energy"][6], atol=1e-5 ) # make sure that the total energy is \sum E_i assert torch.allclose( - output["per_molecule_energy"][0], output['per_atom_energy'][0:5].sum(dim=0), atol=1e-5 + output["per_molecule_energy"][0], + output["per_atom_energy"][0:5].sum(dim=0), + atol=1e-5, ) assert torch.allclose( - output["per_molecule_energy"][1], output['per_atom_energy'][5:9].sum(dim=0), atol=1e-5 + output["per_molecule_energy"][1], + output["per_atom_energy"][5:9].sum(dim=0), + atol=1e-5, ) @@ -418,42 +418,42 @@ def test_calculate_energies_and_forces(model_name, single_batch_with_batchsize_6 import torch # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") + config = load_configs(f"{model_name.lower()}", "qm9") # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) training_parameter = config["training"].get("training_parameter", {}) + + # get batch nnp_input = single_batch_with_batchsize_64.nnp_input - # test the backward pass through each of the models - nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] - nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] - # set seed manually + # test the pass through each of the models torch.manual_seed(42) model_inference = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, - model_parameter=potential_parameter, + model_parameter=config["potential"], ) - E_inference = model_inference(nnp_input)["E"] + E_inference = model_inference(nnp_input)["per_molecule_energy"] # backpropagation F_inference = -torch.autograd.grad( E_inference.sum(), nnp_input.positions, create_graph=True, retain_graph=True )[0] + # make sure that dimension are as expected + nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] + nr_of_atoms_per_batch = nnp_input.atomic_subsystem_indices.shape[0] + assert E_inference.shape == torch.Size([nr_of_mols]) assert F_inference.shape == (nr_of_atoms_per_batch, 3) # only one molecule torch.manual_seed(42) model_training = NeuralNetworkPotentialFactory.generate_model( use="training", - model_type=model_name, - model_parameter=potential_parameter, + model_parameter=config["potential"], training_parameter=training_parameter, ) - E_training = model_training.model.forward(nnp_input)["E"] + E_training = model_training.model.forward(nnp_input)["per_molecule_energy"] F_training = -torch.autograd.grad( E_training.sum(), nnp_input.positions, create_graph=True, retain_graph=True )[0] @@ -473,10 +473,7 @@ def test_calculate_energies_and_forces_with_jax( import torch # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") - - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs(f"{model_name.lower()}", "qm9") nnp_input = single_batch_with_batchsize_64.nnp_input # test the backward pass through each of the models @@ -486,14 +483,13 @@ def test_calculate_energies_and_forces_with_jax( # The 
inference_model fixture now returns a function that expects an environment model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, - model_parameter=potential_parameter, + model_parameter=config["potential"], simulation_environment="JAX", ) nnp_input = nnp_input.as_jax_namedtuple() - result = model(nnp_input)["E"] + result = model(nnp_input)["per_molecule_energy"] import jax @@ -876,16 +872,12 @@ def test_casting(model_name, single_batch_with_batchsize_64): # cast input and model to torch.float64 # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") - - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs(f"{model_name.lower()}", "qm9") model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment="PyTorch", - model_parameter=potential_parameter, + model_parameter=config["potential"], ) model = model.to(dtype=torch.float64) nnp_input = batch.nnp_input.to(dtype=torch.float64) @@ -895,9 +887,8 @@ def test_casting(model_name, single_batch_with_batchsize_64): # cast input and model to torch.float64 model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment="PyTorch", - model_parameter=potential_parameter, + model_parameter=config["potential"], ) model = model.to(dtype=torch.float32) nnp_input = batch.nnp_input.to(dtype=torch.float32) diff --git a/modelforge/train/training.py b/modelforge/train/training.py index 95726b24..960ddd37 100644 --- a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -1,12 +1,10 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau import lightning as pl -from typing import TYPE_CHECKING, Any, Union, Dict, Type, Optional +from typing import TYPE_CHECKING, Any, Union, Dict, Type, Optional, List import torch from loguru import logger as log from modelforge.dataset.dataset import BatchData -if TYPE_CHECKING: - from modelforge.potential.utils import BatchData import torchmetrics from torchmetrics.utilities import dim_zero_cat @@ -55,181 +53,102 @@ def compute(self) -> torch.Tensor: from torch import nn +from torch_scatter import scatter_sum -from abc import abstractmethod +class PerAtomToPerMoleculeError(nn.Module): -class Loss(nn.Module): - """ - Abstract base class for loss calculation in neural network potentials. - """ + def __init__(self): + from torch.nn import MSELoss - @abstractmethod - def calculate_loss( - self, predict_target: Dict[str, torch.Tensor], batch: BatchData - ) -> Dict[str, torch.Tensor]: - pass + super().__init__() + self.loss = MSELoss() + def forward( + self, predicted: torch.Tensor, true: torch.Tensor, batch + ) -> torch.Tensor: -class LossFactory(object): - """ - Factory class to create different types of loss functions. - """ + # squaared error + error_per_atom = torch.norm(predicted - true, dim=1) ** 2 - @staticmethod - def create_loss(loss_type: str, **kwargs) -> Type[Loss]: - """ - Creates an instance of the specified loss type. + # Aggregate error per molecule + error_per_molecule = scatter_sum( + error_per_atom, batch.nnp_input.atomic_subsystem_indices.long(), 0 + ) - Parameters - ---------- - loss_type : str - The type of loss function to create. - **kwargs : dict - Additional parameters for the loss function. 
+ # divide by nnumber of atoms + return error_per_molecule / batch.metadata.atomic_subsystem_counts - Returns - ------- - Loss - An instance of the specified loss function. - """ - if loss_type == "EnergyAndForceLoss": - return EnergyAndForceLoss(**kwargs) - elif loss_type == "EnergyLoss": - return EnergyLoss() - else: - raise ValueError(f"Loss type {loss_type} not implemented.") +class PerMoleculeError(nn.Module): + def __init__(self): + from torch.nn import MSELoss -class EnergyLoss(Loss): - """ - Class to calculate the energy loss using Mean Squared Error (MSE). - """ + super().__init__() - def __init__( - self, - ): - """ - Initializes the EnergyLoss class. - """ + self.loss = MSELoss() - super().__init__() - from torch.nn import MSELoss + def forward( + self, predicted: torch.Tensor, true: torch.Tensor, batch + ) -> torch.Tensor: - self.mse_loss = MSELoss() + # divide by number of atoms + return self.loss(predicted, true) / batch.metadata.atomic_subsystem_counts - def calculate_loss( - self, predict_target: Dict[str, torch.Tensor], batch: Optional[BatchData] = None - ) -> Dict[str, torch.Tensor]: - """ - Calculates the energy loss. - Parameters - ---------- - predict_target : dict - Dictionary containing predicted and true values for energy. - batch : BatchData, optional - Batch of data, by default None. +class Loss(nn.Module): - Returns - ------- - dict - Dictionary containing combined loss, energy loss, and force loss. - """ - E_loss = self.mse_loss(predict_target["E_predict"], predict_target["E_true"]) + _SUPPORTED_PROPERTIES = ["per_molecule_energy", "force"] - return { - "combined_loss": E_loss, - "energy_loss": E_loss, - "force_loss": torch.zeros_like(E_loss), - } + def __init__(self, loss_porperty: List[str], weight: Dict[str, float]): + super().__init__() + from torch.nn import ModuleDict -class EnergyAndForceLoss(Loss): - """ - Class to calculate the combined loss for both energy and force predictions. + self.loss_property = loss_porperty + self.weight = weight - Attributes - ---------- - include_force : bool - Whether to include force in the loss calculation. - energy_weight : torch.Tensor - Weight for the energy loss component. - force_weight : torch.Tensor - Weight for the force loss component. - """ + self.loss = ModuleDict() - def __init__( - self, - include_force: bool = False, - energy_weight: float = 1.0, - force_weight: float = 1.0, - ): - """ - Initializes the EnergyAndForceLoss class. + for prop, w in weight.items(): + if prop in self._SUPPORTED_PROPERTIES: + if prop == "force": + self.loss[prop] = PerAtomToPerMoleculeError() + else: + self.loss[prop] = PerMoleculeError() + self.register_buffer(prop, torch.tensor(w)) + else: + raise NotImplementedError(f"Loss type {prop} not implemented.") - Parameters - ---------- - include_force : bool, optional - Whether to include force in the loss calculation, by default False. - energy_weight : float, optional - Weight for the energy loss component, by default 1.0. - force_weight : float, optional - Weight for the force loss component, by default 1.0. - """ - super().__init__() - self.include_force = include_force - self.register_buffer("energy_weight", torch.tensor(energy_weight)) - self.register_buffer("force_weight", torch.tensor(force_weight)) + def forward(self, predict_target: Dict[str, torch.Tensor], batch): - def calculate_loss( - self, predict_target: Dict[str, torch.Tensor], batch: BatchData - ) -> Dict[str, torch.Tensor]: - """ - Calculates the combined loss for both energy and force predictions. 
+ loss = torch.zeros_like(predict_target["E_true"]) - Parameters - ---------- - predict_target : dict - Dictionary containing predicted and true values for energy and force. - Expected keys are 'E_predict', 'E_true', 'F_predict', 'F_true'. - batch : BatchData - Batch of data, including input features and target values. + for prop in self.loss_property: + loss += self.weight[prop] * self.loss[prop]( + predict_target[prop], predict_target[f"{prop}_true"], batch + ) + + return loss + +class LossFactory(object): + """ + Factory class to create different types of loss functions. + """ + + @staticmethod + def create_loss(loss_property: List[str], weight: Dict[str, float]) -> Type[Loss]: + """ + Creates an instance of the specified loss type. Returns ------- - dict - Dictionary containing combined loss, energy loss, and force loss. + Loss + An instance of the specified loss function. """ - from torch_scatter import scatter_sum - # Calculate per-atom force error - F_error_per_atom = ( - torch.norm(predict_target["F_predict"] - predict_target["F_true"], dim=1) - ** 2 - ) - # Aggregate force error per molecule - F_error_per_molecule = scatter_sum( - F_error_per_atom, batch.nnp_input.atomic_subsystem_indices.long(), 0 - ) - - # Scale factor for force loss - scale = self.force_weight / (3 * batch.metadata.atomic_subsystem_counts) - # Calculate energy loss - E_loss = ( - self.energy_weight - * (predict_target["E_predict"] - predict_target["E_true"]) ** 2 - ) - # Calculate force loss - F_loss = scale * F_error_per_molecule - # Combine energy and force losses - combined_loss = torch.mean(E_loss + F_loss) - return { - "combined_loss": combined_loss, - "energy_loss": E_loss, - "force_loss": F_loss, - } + return Loss(loss_property, weight) from torch.optim import Optimizer @@ -273,26 +192,21 @@ def __init__( super().__init__() self.save_hyperparameters() + # Extracting and instantiating the model from parameters - model_parameter_ = model_parameter.copy() - model_name = model_parameter_.pop("model_name", None) - if model_name is None: - raise ValueError( - "NNP name must be specified in nnp_parameters with key 'model_name'." 
- ) + model_name = model_parameter["model_name"] + # Get requested model class nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_name) - + # initialize model self.model = nnp_class( - **model_parameter_, + **model_parameter["core_parameter"], dataset_statistic=dataset_statistic, + postprocessing_parameter=model_parameter["postprocessing_parameter"], ) self.optimizer = optimizer self.learning_rate = lr self.lr_scheduler_config = lr_scheduler_config - self.loss_module = LossFactory.create_loss(**loss_parameter) - - self.unused_parameters = set() - self.are_unused_parameters_present = False + self.loss = LossFactory.create_loss(**loss_parameter) self.val_error = { "energy": MetricCollection( @@ -368,6 +282,7 @@ def _get_forces( retain_graph=True, )[0] F_predict = -1 * grad # Forces are the negative gradient of energy + return {"F_true": F_true, "F_predict": F_predict} def _get_energies(self, batch: "BatchData") -> Dict[str, torch.Tensor]: @@ -474,8 +389,9 @@ def training_step(self, batch: "BatchData", batch_idx: int) -> torch.Tensor: # calculate energy and forces predict_target = self._get_predictions(batch) + # calculate the loss - loss_dict = self.loss_module.calculate_loss(predict_target, batch) + loss_dict = self.loss_module(predict_target, batch) # Update and log metrics self._log_metrics(self.train_error, predict_target) # log loss @@ -782,7 +698,6 @@ def log_training_arguments( else: log.info(f"Using cache directory: {local_cache_dir}") - accelerator = training_config.get("accelerator", "cpu") if accelerator == "cpu": log.info(f"Using default accelerator: {accelerator}") From a0dde3199cec1f30cb01c9491f30cb73ba3cacf3 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 15:57:20 +0200 Subject: [PATCH 45/78] update loss docstrings --- modelforge/potential/models.py | 46 +++++----- modelforge/train/training.py | 160 ++++++++++++++++++++++----------- 2 files changed, 133 insertions(+), 73 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 0a675c4f..24b4ec17 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -786,33 +786,33 @@ def _initialize_postprocessing( prostprocessing_sequence_names.append( "from_atom_to_molecule_reduction" ) + elif property == "general_postprocessing_operation": + # check if also self-energies are requested + if operations.get("calculate_molecular_self_energy", False): - # check if also self-energies are requested - if operations.get("calculate_molecular_self_energy", False): - - if self.dataset_statistic is None: - log.warning( - "Dataset statistics are required to calculate the molecular self-energies but haven't been provided." - ) - else: - atomic_self_energies = self.dataset_statistic[ - "atomic_self_energies" - ] + if self.dataset_statistic is None: + log.warning( + "Dataset statistics are required to calculate the molecular self-energies but haven't been provided." 
+ ) + else: + atomic_self_energies = self.dataset_statistic[ + "atomic_self_energies" + ] - postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies) - ) - prostprocessing_sequence_names.append( - "calculate_molecular_self_energy" - ) + postprocessing_sequence.append( + CalculateAtomicSelfEnergy(atomic_self_energies) + ) + prostprocessing_sequence_names.append( + "calculate_molecular_self_energy" + ) - postprocessing_sequence.append( - FromAtomToMoleculeReduction( - per_atom_property_name="ase_tensor", - index_name="atomic_subsystem_indices", - output_name="per_molecule_self_energy", + postprocessing_sequence.append( + FromAtomToMoleculeReduction( + per_atom_property_name="ase_tensor", + index_name="atomic_subsystem_indices", + output_name="per_molecule_self_energy", + ) ) - ) # check if also self-energies are requested if operations.get("calculate_atomic_self_energy", False): diff --git a/modelforge/train/training.py b/modelforge/train/training.py index 960ddd37..91e4ced6 100644 --- a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -7,101 +7,138 @@ import torchmetrics -from torchmetrics.utilities import dim_zero_cat from typing import Optional -class LogLoss(torchmetrics.Metric): +from torch import nn +from torch_scatter import scatter_sum + + +class PerAtomToPerMoleculeError(nn.Module): """ - Custom metric to log the loss function. + Calculates the per-atom error and aggregates it to per-molecule mean squared error. - Attributes - ---------- - loss_per_batch : List[torch.Tensor] - List to store the loss for each batch. """ - def __init__(self) -> None: + def __init__(self): """ - Initializes the LogLoss class, setting up the state for the metric. + Initializes the PerAtomToPerMoleculeError class. """ + + from torch.nn import MSELoss + super().__init__() - self.add_state("loss_per_batch", default=[], dist_reduce_fx="cat") - def update(self, loss: torch.Tensor) -> None: + def forward( + self, predicted: torch.Tensor, true: torch.Tensor, batch + ) -> torch.Tensor: """ - Updates the metric state with the loss for a batch. + Computes the per-atom error and aggregates it to per-molecule mean squared error. Parameters ---------- - loss : torch.Tensor - The loss for a batch. - """ - self.loss_per_batch.append(loss.detach()) - - def compute(self) -> torch.Tensor: - """ - Computes the average loss over all batches in an epoch. + predicted : torch.Tensor + The predicted values. + true : torch.Tensor + The true values. + batch : Any + The batch data containing metadata and input information. Returns ------- torch.Tensor - The average loss for the epoch. + The aggregated per-molecule error. 
""" - mse_loss_per_epoch = dim_zero_cat(self.loss_per_batch) - return torch.mean(mse_loss_per_epoch) - - -from torch import nn -from torch_scatter import scatter_sum - - -class PerAtomToPerMoleculeError(nn.Module): - - def __init__(self): - from torch.nn import MSELoss - - super().__init__() - self.loss = MSELoss() - - def forward( - self, predicted: torch.Tensor, true: torch.Tensor, batch - ) -> torch.Tensor: # squaared error - error_per_atom = torch.norm(predicted - true, dim=1) ** 2 + per_atom_squared_error = torch.norm(predicted - true, dim=1) ** 2 # Aggregate error per molecule - error_per_molecule = scatter_sum( - error_per_atom, batch.nnp_input.atomic_subsystem_indices.long(), 0 + per_molecule_squared_error = scatter_sum( + per_atom_squared_error, batch.nnp_input.atomic_subsystem_indices.long(), 0 + ) + per_molecule_square_error_scaled = ( + per_molecule_squared_error / batch.metadata.atomic_subsystem_counts ) - # divide by nnumber of atoms - return error_per_molecule / batch.metadata.atomic_subsystem_counts + return per_molecule_square_error_scaled class PerMoleculeError(nn.Module): + """ + Calculates the per-molecule mean squared error. + + """ def __init__(self): - from torch.nn import MSELoss + """ + Initializes the PerMoleculeError class. + """ super().__init__() - self.loss = MSELoss() - def forward( self, predicted: torch.Tensor, true: torch.Tensor, batch ) -> torch.Tensor: + """ + Computes the per-molecule mean squared error. - # divide by number of atoms - return self.loss(predicted, true) / batch.metadata.atomic_subsystem_counts + Parameters + ---------- + predicted : torch.Tensor + The predicted values. + true : torch.Tensor + The true values. + batch : Any + The batch data containing metadata and input information. + + Returns + ------- + torch.Tensor + The mean per-molecule error. + """ + + per_molecule_squared_error = (predicted - true) ** 2 + per_molecule_square_error_scaled = ( + per_molecule_squared_error / batch.metadata.atomic_subsystem_counts + ) + + # average + return torch.mean(per_molecule_square_error_scaled) class Loss(nn.Module): + """ + Calculates the combined loss for energy and force predictions. + + Attributes + ---------- + loss_property : List[str] + List of properties to include in the loss calculation. + weight : Dict[str, float] + Dictionary containing the weights for each property in the loss calculation. + loss : nn.ModuleDict + Module dictionary containing the loss functions for each property. + """ _SUPPORTED_PROPERTIES = ["per_molecule_energy", "force"] def __init__(self, loss_porperty: List[str], weight: Dict[str, float]): + """ + Initializes the Loss class. + + Parameters + ---------- + loss_property : List[str] + List of properties to include in the loss calculation. + weight : Dict[str, float] + Dictionary containing the weights for each property in the loss calculation. + + Raises + ------ + NotImplementedError + If an unsupported loss type is specified. + """ super().__init__() from torch.nn import ModuleDict @@ -122,6 +159,21 @@ def __init__(self, loss_porperty: List[str], weight: Dict[str, float]): raise NotImplementedError(f"Loss type {prop} not implemented.") def forward(self, predict_target: Dict[str, torch.Tensor], batch): + """ + Calculates the combined loss for the specified properties. + + Parameters + ---------- + predict_target : Dict[str, torch.Tensor] + Dictionary containing predicted and true values for energy and force. + batch : Any + The batch data containing metadata and input information. 
+ + Returns + ------- + torch.Tensor + The combined loss for the specified properties. + """ loss = torch.zeros_like(predict_target["E_true"]) @@ -142,6 +194,14 @@ class LossFactory(object): def create_loss(loss_property: List[str], weight: Dict[str, float]) -> Type[Loss]: """ Creates an instance of the specified loss type. + + Parameters + ---------- + loss_property : List[str] + List of properties to include in the loss calculation. + weight : Dict[str, float] + Dictionary containing the weights for each property in the loss calculation. + Returns ------- Loss From aeec947ad4bd2c30aeeb2b25377d868f5263b938 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 16:46:22 +0200 Subject: [PATCH 46/78] fixing tests --- modelforge/potential/models.py | 4 +- .../tests/data/potential_defaults/schnet.toml | 1 + .../tests/data/training_defaults/default.toml | 6 +- modelforge/tests/test_ani.py | 12 +-- modelforge/tests/test_models.py | 56 ++++---------- modelforge/tests/test_painn.py | 74 +++++++------------ modelforge/tests/test_physnet.py | 22 +++--- modelforge/tests/test_sake.py | 33 +++++---- modelforge/tests/test_schnet.py | 67 +++++++++-------- modelforge/train/training.py | 13 ++-- 10 files changed, 124 insertions(+), 164 deletions(-) diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 24b4ec17..2ac93276 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -688,7 +688,7 @@ def _input_checks(self, data: Union[NNPInput, NamedTuple]): class PostProcessing(torch.nn.Module): - _SUPPORTED_PROPERTIES = ["per_atom_energy"] + _SUPPORTED_PROPERTIES = ["per_atom_energy", "general_postprocessing_operation"] _SUPPORTED_OPERATIONS = ["normalize", "from_atom_to_molecule_reduction"] def __init__( @@ -893,7 +893,7 @@ def load_state_dict( # Prefix to remove prefix = "model." 
- excluded_keys = ["loss.per_molecule_energy", "loss.force"] + excluded_keys = ["loss.per_molecule_energy_error", "loss.per_atom_force_error"] # Create a new dictionary without the prefix in the keys if prefix exists if any(key.startswith(prefix) for key in state_dict.keys()): diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index 8626206f..c303069e 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -15,5 +15,6 @@ shared_interactions = false normalize = true from_atom_to_molecule_reduction = true keep_per_atom_property = true +[potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true #calculate_atomic_self_energy = true diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml index 664e32fa..34a3df99 100644 --- a/modelforge/tests/data/training_defaults/default.toml +++ b/modelforge/tests/data/training_defaults/default.toml @@ -24,10 +24,10 @@ monitor = "val/per_molecule_energy/rmse" interval = "epoch" [training.training_parameter.loss_parameter] -loss_property = ['per_molecule_energy', 'force'] +loss_property = ['per_molecule_energy_error', 'per_atom_force_error'] [training.training_parameter.loss_parameter.weight] -per_molecule_energy = 0.999 -force = 0.001 +per_molecule_energy_error = 0.999 +per_atom_force_error = 0.001 [training.early_stopping] verbose = true diff --git a/modelforge/tests/test_ani.py b/modelforge/tests/test_ani.py index 38c51edb..d7f04ec5 100644 --- a/modelforge/tests/test_ani.py +++ b/modelforge/tests/test_ani.py @@ -93,7 +93,7 @@ def test_forward_and_backward_using_torchani(): energy = model((species, coordinates)).energies derivative = torch.autograd.grad(energy.sum(), coordinates)[0] - force = -derivative + per_atom_force = -derivative def test_forward_and_backward(): @@ -113,7 +113,7 @@ def test_forward_and_backward(): model = ANI2x(**potential_parameter).to(device=device) energy = model(mf_input) derivative = torch.autograd.grad(energy["E"].sum(), mf_input.positions)[0] - force = -derivative + per_atom_force = -derivative def test_representation(): @@ -188,16 +188,12 @@ def test_representation_with_diagonal_batching(): reference_rbf_output, ani_d_ij = ( provide_reference_values_for_test_ani_test_compute_rsf_with_diagonal_batching() ) - assert torch.allclose( - calculated_rbf_output, reference_rbf_output, atol=1e-4 - ) + assert torch.allclose(calculated_rbf_output, reference_rbf_output, atol=1e-4) assert torch.allclose( ani_d_ij, d_ij.squeeze(1) * 10, atol=1e-4 ) # NOTE: unit mismatch - assert calculated_rbf_output.shape == torch.Size( - [20, radial_dist_divisions] - ) + assert calculated_rbf_output.shape == torch.Size([20, radial_dist_divisions]) def test_compare_angular_symmetry_features(): diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index d5fbaddf..ca8c900d 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -32,22 +32,18 @@ def test_JAX_wrapping(model_name, single_batch_with_batchsize_64): ) # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") - - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs(f"{model_name.lower()}", "qm9") # inference model model = NeuralNetworkPotentialFactory.generate_model( use="inference", - 
model_type=model_name, simulation_environment="JAX", - model_parameter=potential_parameter, + model_parameter=config["potential"], ) assert "JAX" in str(type(model)) nnp_input = single_batch_with_batchsize_64.nnp_input.as_jax_namedtuple() - out = model(nnp_input)["E"] + out = model(nnp_input)["per_molecule_energy"] import jax grad_fn = jax.grad(lambda pos: out.sum()) # Create a gradient function @@ -65,20 +61,13 @@ def test_model_factory(model_name, simulation_environment): from modelforge.train.training import TrainingAdapter # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") - - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) - - # Setup loss - from modelforge.train.training import return_toml_config + config = load_configs(f"{model_name.lower()}", "qm9") # inference model model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment=simulation_environment, - model_parameter=potential_parameter, + model_parameter=config["potential"], ) assert ( model_name.upper() in str(type(model)).upper() @@ -90,9 +79,8 @@ def test_model_factory(model_name, simulation_environment): # training model model = NeuralNetworkPotentialFactory.generate_model( use="training", - model_type=model_name, simulation_environment=simulation_environment, - model_parameter=potential_parameter, + model_parameter=config["potential"], training_parameter=training_parameter, ) assert type(model) == TrainingAdapter @@ -163,28 +151,23 @@ def test_state_dict_saving_and_loading(model_name): import torch # read default parameters - config = load_configs(f"{model_name.lower()}_without_ase", "qm9") + config = load_configs(f"{model_name.lower()}", "qm9") # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) training_parameter = config["training"].get("training_parameter", {}) - # Setup loss - from modelforge.train.training import return_toml_config model1 = NeuralNetworkPotentialFactory.generate_model( use="training", - model_type=model_name, simulation_environment="PyTorch", - model_parameter=potential_parameter, + model_parameter=config["potential"], training_parameter=training_parameter, ) torch.save(model1.state_dict(), "model.pth") model2 = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment="PyTorch", - model_parameter=potential_parameter, + model_parameter=config["potential"], ) model2.load_state_dict(torch.load("model.pth")) @@ -911,20 +894,13 @@ def test_equivariant_energies_and_forces( import torch from dataclasses import replace - # cast input and model to torch.float64 - # read default parameters - config = load_configs(f"{model_name}_without_ase", "qm9") - - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) - # Setup loss - from modelforge.train.training import return_toml_config + # load default parameters + config = load_configs(f"{model_name}", "qm9") model = NeuralNetworkPotentialFactory.generate_model( use="inference", - model_type=model_name, simulation_environment=simulation_environment, - model_parameter=potential_parameter, + model_parameter=config["potential"], ) # define the symmetry operations @@ -940,7 +916,7 @@ def test_equivariant_energies_and_forces( # start the test # reference values nnp_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float64) - reference_result = 
model(nnp_input)["E"].to(dtype=torch.float64) + reference_result = model(nnp_input)["per_molecule_energy"].to(dtype=torch.float64) reference_forces = -torch.autograd.grad( reference_result.sum(), nnp_input.positions, @@ -949,7 +925,7 @@ def test_equivariant_energies_and_forces( # translation test translation_nnp_input = replace(nnp_input) translation_nnp_input.positions = translation(translation_nnp_input.positions) - translation_result = model(translation_nnp_input)["E"] + translation_result = model(translation_nnp_input)["per_molecule_energy"] assert torch.allclose( translation_result, reference_result, @@ -974,7 +950,7 @@ def test_equivariant_energies_and_forces( # rotation test rotation_input_data = replace(nnp_input) rotation_input_data.positions = rotation(rotation_input_data.positions) - rotation_result = model(rotation_input_data)["E"] + rotation_result = model(rotation_input_data)["per_molecule_energy"] for t, r in zip(rotation_result, reference_result): if not torch.allclose(t, r, atol=atol): @@ -1003,7 +979,7 @@ def test_equivariant_energies_and_forces( # reflection test reflection_input_data = replace(nnp_input) reflection_input_data.positions = reflection(reflection_input_data.positions) - reflection_result = model(reflection_input_data)["E"] + reflection_result = model(reflection_input_data)["per_molecule_energy"] reflection_forces = -torch.autograd.grad( reflection_result.sum(), reflection_input_data.positions, diff --git a/modelforge/tests/test_painn.py b/modelforge/tests/test_painn.py index 3800600d..8eb3be07 100644 --- a/modelforge/tests/test_painn.py +++ b/modelforge/tests/test_painn.py @@ -8,16 +8,16 @@ def test_forward(single_batch_with_batchsize_64): from modelforge.tests.test_models import load_configs # read default parameters - config = load_configs("painn_without_ase", "qm9") + config = load_configs("painn", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) - - painn = PaiNN(**potential_parameter) + painn = PaiNN( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) assert painn is not None, "PaiNN model should be initialized." 
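The equivariance tests above all follow one recipe: apply a rigid transformation to the input positions and check that the predicted per-molecule energy is unchanged while the forces co-rotate. A condensed, self-contained sketch of that recipe (the scalar energy_fn interface is assumed here for illustration; the tests themselves build models through NeuralNetworkPotentialFactory):

import torch

def check_rigid_rotation(energy_fn, positions, atol=1e-5):
    # energy_fn: positions [n_atoms, 3] -> per-molecule energy (assumed interface)
    # 90 degree rotation about the z-axis, matching the convention in the tests above
    rotation = torch.tensor(
        [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], dtype=positions.dtype
    )
    positions = positions.clone().requires_grad_(True)
    reference_energy = energy_fn(positions)
    reference_forces = -torch.autograd.grad(reference_energy.sum(), positions)[0]

    rotated = (positions.detach() @ rotation.T).requires_grad_(True)
    rotated_energy = energy_fn(rotated)
    rotated_forces = -torch.autograd.grad(rotated_energy.sum(), rotated)[0]

    # energies are invariant under rotation; forces rotate with the input frame
    assert torch.allclose(rotated_energy, reference_energy, atol=atol)
    assert torch.allclose(rotated_forces, reference_forces @ rotation.T, atol=atol)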
nnp_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float32) - energy = painn(nnp_input)["E"] + energy = painn(nnp_input)["per_molecule_energy"] nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] assert ( @@ -33,10 +33,7 @@ def test_equivariance(single_batch_with_batchsize_64): from modelforge.tests.test_models import load_configs # read default parameters - config = load_configs("painn_without_ase", "qm9") - - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs("painn", "qm9") # define a rotation matrix in 3D that rotates by 90 degrees around the z-axis # (clockwise when looking along the z-axis towards the origin) @@ -44,7 +41,10 @@ def test_equivariance(single_batch_with_batchsize_64): [[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64 ) - painn = PaiNN(**potential_parameter).to(torch.float64) + painn = PaiNN( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ).to(torch.float64) methane_input = single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float64) perturbed_methane_input = replace(methane_input) perturbed_methane_input.positions = torch.matmul( @@ -153,52 +153,30 @@ def test_equivariance(single_batch_with_batchsize_64): from modelforge.tests.test_schnet import setup_single_methane_input -def setup_representation( - cutoff, nr_atom_basis, number_of_gaussians, nr_of_interactions -): - # ------------------------------------ # - # set up the modelforge Painn representation model - # which means that we only want to call the - # _transform_input() method - from modelforge.potential.painn import PaiNN - - return PaiNN( - max_Z=100, - number_of_atom_features=nr_atom_basis, - number_of_interaction_modules=nr_of_interactions, - number_of_radial_basis_functions=number_of_gaussians, - cutoff=cutoff, - shared_interactions=False, - shared_filters=False, - processing_operation=[], - readout_operation=[ - { - "step": "from_atom_to_molecule", - "mode": "sum", - "in": per_atom_energy, - "index_key": "atomic_subsystem_indices", - "out": "E", - } - ], - ) - - def test_compare_representation(): # ---------------------------------------- # # setup the PaiNN model # ---------------------------------------- # from openff.units import unit from .precalculated_values import load_precalculated_painn_results + from modelforge.tests.test_models import load_configs + + # read default parameters + config = load_configs("painn", "qm9") - cutoff = unit.Quantity(5.0, unit.angstrom) - nr_atom_basis = 8 - number_of_gaussians = 5 - nr_of_interactions = 3 torch.manual_seed(1234) - model = setup_representation( - cutoff, nr_atom_basis, number_of_gaussians, nr_of_interactions - ).double() + # override defaults to match reference implementation in spk + config["potential"]["core_parameter"]["max_Z"] = 100 + config["potential"]["core_parameter"]["number_of_atom_features"] = 8 + config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5 + + # initialize model + model = PaiNN( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ).to(torch.float64) + # ------------------------------------ # # set up the input for the Painn model input = setup_single_methane_input() diff --git a/modelforge/tests/test_physnet.py b/modelforge/tests/test_physnet.py index 21e1a6c3..4db86d34 100644 --- a/modelforge/tests/test_physnet.py +++ 
b/modelforge/tests/test_physnet.py @@ -5,11 +5,12 @@ def test_init(): from modelforge.tests.test_models import load_configs # read default parameters - config = load_configs(f"physnet_without_ase", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs(f"physnet", "qm9") - model = PhysNet(**potential_parameter) + model = PhysNet( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) def test_forward(single_batch_with_batchsize_64): @@ -20,15 +21,16 @@ def test_forward(single_batch_with_batchsize_64): from modelforge.tests.test_models import load_configs # read default parameters - config = load_configs(f"physnet_without_ase", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs(f"physnet", "qm9") # Extract parameters - potential_parameter["number_of_modules"] = 1 - potential_parameter["number_of_interaction_residual"] = 1 + config["potential"]["core_parameter"]["number_of_modules"] = 1 + config["potential"]["core_parameter"]["number_of_interaction_residual"] = 1 - model = PhysNet(**potential_parameter) + model = PhysNet( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) model = model.to(torch.float32) print(model) yhat = model(single_batch_with_batchsize_64.nnp_input.to(dtype=torch.float32)) diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index ccfe4771..94c2709a 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -14,23 +14,25 @@ ON_MAC = platform == "darwin" -def test_SAKE_init(): +def test_init(): """Test initialization of the SAKE neural network potential.""" from modelforge.tests.test_models import load_configs # read default parameters - config = load_configs(f"sake_without_ase", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs(f"sake", "qm9") - sake = SAKE(**potential_parameter) + # initialize model + sake = SAKE( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) assert sake is not None, "SAKE model should be initialized." from openff.units import unit -def test_sake_forward(single_batch_with_batchsize_64): +def test_forward(single_batch_with_batchsize_64): """ Test the forward pass of the SAKE model. 
""" @@ -40,12 +42,13 @@ def test_sake_forward(single_batch_with_batchsize_64): from modelforge.tests.test_models import load_configs # read default parameters - config = load_configs(f"sake_without_ase", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs(f"sake", "qm9") - sake = SAKE(**potential_parameter) - energy = sake(methane)["E"] + sake = SAKE( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) + energy = sake(methane)["per_molecule_energy"] nr_of_mols = methane.atomic_subsystem_indices.unique().shape[0] assert ( @@ -53,7 +56,7 @@ def test_sake_forward(single_batch_with_batchsize_64): ) # Assuming energy is calculated per sample in the batch -def test_sake_interaction_forward(): +def test_interaction_forward(): nr_atoms = 41 nr_atom_basis = 47 geometry_basis = 3 @@ -84,7 +87,7 @@ def test_sake_interaction_forward(): @pytest.mark.parametrize("eq_atol", [3e-1]) @pytest.mark.parametrize("h_atol", [8e-2]) -def test_sake_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize_64): +def test_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize_64): import torch from modelforge.potential.sake import SAKE from dataclasses import replace @@ -409,7 +412,7 @@ def test_sake_layer_against_reference(include_self_pairs, v_is_none): ) -def test_sake_model_against_reference(single_batch_with_batchsize_1): +def test_model_against_reference(single_batch_with_batchsize_1): nr_heads = 5 nr_atom_basis = 11 max_Z = 13 @@ -431,7 +434,7 @@ def test_sake_model_against_reference(single_batch_with_batchsize_1): { "step": "from_atom_to_molecule", "mode": "sum", - "in": 'per_atom_energy', + "in": "per_atom_energy", "index_key": "atomic_subsystem_indices", "out": "E", } diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index 8e29df07..04353bf7 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -31,7 +31,7 @@ def initialize_model( { "step": "from_atom_to_molecule", "mode": "sum", - "in": 'per_atom_energy', + "in": "per_atom_energy", "index_key": "atomic_subsystem_indices", "out": "E", } @@ -45,11 +45,13 @@ def test_init(): from modelforge.tests.test_models import load_configs - # read default parameters - config = load_configs(f"schnet_without_ase", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) - schnet = SchNet(**potential_parameter) + # load default parameters + config = load_configs(f"schnet", "qm9") + # initialize model + schnet = SchNet( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) assert schnet is not None, "Schnet model should be initialized." 
@@ -159,15 +161,22 @@ def test_compare_forward(): # ---------------------------------------- # # test the implementation of the representation part of the PaiNN model # ---------------------------------------- # - from openff.units import unit + from modelforge.potential.schnet import SchNet + + from modelforge.tests.test_models import load_configs - cutoff = unit.Quantity(5.0, unit.angstrom) - number_of_atom_features = 12 - n_rbf = 5 - nr_of_interactions = 3 torch.manual_seed(1234) - modelforge_schnet = initialize_model( - cutoff, number_of_atom_features, n_rbf, nr_of_interactions + # load default parameters + config = load_configs(f"schnet", "qm9") + + # override default parameters + config["potential"]["core_parameter"]["number_of_atom_features"] = 12 + config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5 + + # initialize model + schnet = SchNet( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], ).double() # ------------------------------------ # @@ -180,10 +189,10 @@ def test_compare_forward(): spk_input = input["spk_methane_input"] model_input = input["modelforge_methane_input"] - modelforge_schnet.input_preparation._input_checks(model_input) + schnet.input_preparation._input_checks(model_input) - pairlist_output = modelforge_schnet.input_preparation.prepare_inputs(model_input) - prepared_input = modelforge_schnet.core_module._model_specific_input_preparation( + pairlist_output = schnet.input_preparation.prepare_inputs(model_input) + prepared_input = schnet.core_module._model_specific_input_preparation( model_input, pairlist_output ) @@ -225,8 +234,10 @@ def test_compare_forward(): ], dtype=torch.float64, ) - calculated_phi_ij = modelforge_schnet.core_module.schnet_representation_module.radial_symmetry_function_module( - d_ij.unsqueeze(1) / 10 + calculated_phi_ij = ( + schnet.core_module.schnet_representation_module.radial_symmetry_function_module( + d_ij.unsqueeze(1) / 10 + ) ) # NOTE: converting to nm assert torch.allclose(reference_phi_ij, calculated_phi_ij, atol=1e-3) @@ -258,10 +269,8 @@ def test_compare_forward(): ], dtype=torch.float64, ) - calculated_fcut = ( - modelforge_schnet.core_module.schnet_representation_module.cutoff_module( - d_ij / 10 - ) + calculated_fcut = schnet.core_module.schnet_representation_module.cutoff_module( + d_ij / 10 ) # NOTE: converting to nm assert torch.allclose(reference_fcut, calculated_fcut, atol=1e-4) @@ -271,21 +280,17 @@ def test_compare_forward(): # Check full pass torch.manual_seed(1234) - for i in range(nr_of_interactions): - modelforge_schnet.core_module.interaction_modules[ - i - ].intput_to_feature.reset_parameters() + for i in range(3): + schnet.core_module.interaction_modules[i].intput_to_feature.reset_parameters() for j in range(2): - modelforge_schnet.core_module.interaction_modules[i].feature_to_output[ + schnet.core_module.interaction_modules[i].feature_to_output[ j ].reset_parameters() - modelforge_schnet.core_module.interaction_modules[i].filter_network[ + schnet.core_module.interaction_modules[i].filter_network[ j ].reset_parameters() - calculated_results = modelforge_schnet.core_module.forward( - model_input, pairlist_output - ) + calculated_results = schnet.core_module.forward(model_input, pairlist_output) reference_results = load_precalculated_schnet_results() assert ( reference_results["scalar_representation"].shape diff --git a/modelforge/train/training.py b/modelforge/train/training.py index 91e4ced6..9430d2f4 100644 --- 
a/modelforge/train/training.py +++ b/modelforge/train/training.py @@ -14,7 +14,7 @@ from torch_scatter import scatter_sum -class PerAtomToPerMoleculeError(nn.Module): +class FromPerAtomToPerMoleculeError(nn.Module): """ Calculates the per-atom error and aggregates it to per-molecule mean squared error. @@ -121,7 +121,7 @@ class Loss(nn.Module): Module dictionary containing the loss functions for each property. """ - _SUPPORTED_PROPERTIES = ["per_molecule_energy", "force"] + _SUPPORTED_PROPERTIES = ["per_molecule_energy_error", "per_atom_force_error"] def __init__(self, loss_porperty: List[str], weight: Dict[str, float]): """ @@ -150,8 +150,8 @@ def __init__(self, loss_porperty: List[str], weight: Dict[str, float]): for prop, w in weight.items(): if prop in self._SUPPORTED_PROPERTIES: - if prop == "force": - self.loss[prop] = PerAtomToPerMoleculeError() + if prop == "per_atom_force_error": + self.loss[prop] = FromPerAtomToPerMoleculeError() else: self.loss[prop] = PerMoleculeError() self.register_buffer(prop, torch.tensor(w)) @@ -165,7 +165,7 @@ def forward(self, predict_target: Dict[str, torch.Tensor], batch): Parameters ---------- predict_target : Dict[str, torch.Tensor] - Dictionary containing predicted and true values for energy and force. + Dictionary containing predicted and true values for energy and per_atom_force. batch : Any The batch data containing metadata and input information. @@ -241,7 +241,6 @@ def __init__( lr : float The learning rate for the optimizer. loss_module : Loss, optional - Whether to include force in the loss function, by default False. optimizer : Type[Optimizer], optional The optimizer class to use for training, by default torch.optim.AdamW. """ @@ -424,7 +423,7 @@ def _log_metrics( predict_target["E_predict"].detach(), predict_target["E_true"].detach(), ) - if property == "force": + if property == "per_atom_force": error_log( predict_target["F_predict"].detach(), predict_target["F_true"].detach(), From d7e819c83f4df1755911a662f0ed7bb9759f9d60 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 19:15:24 +0200 Subject: [PATCH 47/78] fixing tests and names --- modelforge/dataset/dataset.py | 37 +++++++++---------- modelforge/dataset/utils.py | 4 +- modelforge/potential/models.py | 16 ++++---- modelforge/potential/processing.py | 6 +-- .../tests/data/potential_defaults/schnet.toml | 1 - .../tests/data/training_defaults/default.toml | 4 +- modelforge/train/training.py | 12 ++++-- 7 files changed, 42 insertions(+), 38 deletions(-) diff --git a/modelforge/dataset/dataset.py b/modelforge/dataset/dataset.py index 8cd0f26c..ed2f363c 100644 --- a/modelforge/dataset/dataset.py +++ b/modelforge/dataset/dataset.py @@ -1091,10 +1091,10 @@ def prepare_data( f"Loading dataset statistics from disk: {self.dataset_statistic_filename}" ) atomic_self_energies = self._read_atomic_self_energies() - atomic_energies_stats = self._read_atomic_energies_stats() + training_dataset_statistics = self._read_atomic_energies_stats() else: atomic_self_energies = None - atomic_energies_stats = None + training_dataset_statistics = None # obtain the atomic self energies from the dataset dataset_ase = dataset.atomic_self_energies.energies @@ -1113,17 +1113,17 @@ def prepare_data( # calculate the dataset statistic of the dataset # This is done __after__ self energies are removed (if requested) - if atomic_energies_stats is None: + if training_dataset_statistics is None: from modelforge.dataset.utils import calculate_mean_and_variance - atomic_energies_stats = 
calculate_mean_and_variance(torch_dataset) + training_dataset_statistics = calculate_mean_and_variance(torch_dataset) # wrap everything in a dictionary and save it to disk dataset_statistic = { "atomic_self_energies": atomic_self_energies, - "atomic_energies_stats": atomic_energies_stats, + "training_dataset_statistics": training_dataset_statistics, } - if atomic_self_energies and atomic_energies_stats: + if atomic_self_energies and training_dataset_statistics: log.info(dataset_statistic) # save dataset_statistic dictionary to disk as yaml files self._log_dataset_statistic(dataset_statistic) @@ -1138,7 +1138,6 @@ def prepare_data( def _log_dataset_statistic(self, dataset_statistic): """Save the dataset statistics to a file with units""" import toml - # cast units to string atomic_self_energies = { @@ -1146,14 +1145,18 @@ def _log_dataset_statistic(self, dataset_statistic): for key, value in dataset_statistic["atomic_self_energies"].items() } # cast float and kJ/mol on pytorch tensors and then convert to string - atomic_energies_stats = { - key: str(unit.Quantity(value.item(), unit.kilojoule_per_mole)) if isinstance(value, torch.Tensor) else value - for key, value in dataset_statistic["atomic_energies_stats"].items() + training_dataset_statistics = { + key: ( + str(unit.Quantity(value.item(), unit.kilojoule_per_mole)) + if isinstance(value, torch.Tensor) + else value + ) + for key, value in dataset_statistic["training_dataset_statistics"].items() } dataset_statistic = { "atomic_self_energies": atomic_self_energies, - "atomic_energies_stats": atomic_energies_stats, + "training_dataset_statistics": training_dataset_statistics, } toml.dump( dataset_statistic, @@ -1170,17 +1173,13 @@ def _read_atomic_self_energies(self) -> Dict[str, Quantity]: """Read the atomic self energies from a file.""" from modelforge.potential.processing import load_atomic_self_energies - return load_atomic_self_energies( - self.dataset_statistic_filename - ) + return load_atomic_self_energies(self.dataset_statistic_filename) def _read_atomic_energies_stats(self) -> Dict[str, torch.Tensor]: """Read the atomic energies statistics from a file.""" from modelforge.potential.processing import load_atomic_energies_stats - return load_atomic_energies_stats( - self.dataset_statistic_filename - ) + return load_atomic_energies_stats(self.dataset_statistic_filename) def _create_torch_dataset(self, dataset): """Create a PyTorch dataset from the provided dataset instance.""" @@ -1204,9 +1203,7 @@ def _calculate_atomic_self_energies( # Use provided ase dictionary if self.dict_atomic_self_energies: - log.info( - "Using atomic self energies from the provided dictionary." 
- ) + log.info("Using atomic self energies from the provided dictionary.") return self.dict_atomic_self_energies # Use regression to calculate ase diff --git a/modelforge/dataset/utils.py b/modelforge/dataset/utils.py index 7efd77c0..42b36f34 100644 --- a/modelforge/dataset/utils.py +++ b/modelforge/dataset/utils.py @@ -124,8 +124,8 @@ def calculate_mean_and_variance( online_estimator.update(E_scaled) stats = { - "E_i_mean": online_estimator.mean, - "E_i_stddev": online_estimator.stddev, + "per_atom_energy_mean": online_estimator.mean, + "per_atom_energy_stddev": online_estimator.stddev, } log.info(f"Mean and standard deviation of the dataset:{stats}") return stats diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 2ac93276..9c5841a8 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -725,13 +725,15 @@ def _get_mean_and_stddev_of_dataset(self) -> Tuple[float, float]: f"No mean and stddev provided for dataset. Setting to default value {mean=} and {stddev=}!" ) else: - atomic_energies_stats = self.dataset_statistic["atomic_energies_stats"] - mean = unit.Quantity(atomic_energies_stats["E_i_mean"]).m_as( - unit.kilojoule_per_mole - ) - stddev = unit.Quantity(atomic_energies_stats["E_i_stddev"]).m_as( - unit.kilojoule_per_mole - ) + training_dataset_statistics = self.dataset_statistic[ + "training_dataset_statistics" + ] + mean = unit.Quantity( + training_dataset_statistics["per_atom_energy_mean"] + ).m_as(unit.kilojoule_per_mole) + stddev = unit.Quantity( + training_dataset_statistics["per_atom_energy_stddev"] + ).m_as(unit.kilojoule_per_mole) return mean, stddev def _initialize_postprocessing( diff --git a/modelforge/potential/processing.py b/modelforge/potential/processing.py index 4e91216d..9bde7562 100644 --- a/modelforge/potential/processing.py +++ b/modelforge/potential/processing.py @@ -22,12 +22,12 @@ def load_atomic_energies_stats(path: str) -> Dict[str, unit.Quantity]: energy_statistic = toml.load(open(path, "r")) # convert values to tensor - atomic_energies_stats = { + training_dataset_statistics = { key: unit.Quantity(value) - for key, value in energy_statistic["atomic_energies_stats"].items() + for key, value in energy_statistic["training_dataset_statistics"].items() } - return atomic_energies_stats + return training_dataset_statistics class FromAtomToMoleculeReduction(torch.nn.Module): diff --git a/modelforge/tests/data/potential_defaults/schnet.toml b/modelforge/tests/data/potential_defaults/schnet.toml index c303069e..f5b0094d 100644 --- a/modelforge/tests/data/potential_defaults/schnet.toml +++ b/modelforge/tests/data/potential_defaults/schnet.toml @@ -17,4 +17,3 @@ from_atom_to_molecule_reduction = true keep_per_atom_property = true [potential.postprocessing_parameter.general_postprocessing_operation] calculate_molecular_self_energy = true -#calculate_atomic_self_energy = true diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml index 34a3df99..7dbaa5dd 100644 --- a/modelforge/tests/data/training_defaults/default.toml +++ b/modelforge/tests/data/training_defaults/default.toml @@ -24,9 +24,9 @@ monitor = "val/per_molecule_energy/rmse" interval = "epoch" [training.training_parameter.loss_parameter] -loss_property = ['per_molecule_energy_error', 'per_atom_force_error'] +loss_property = ['per_molecule_energy_error', 'per_atom_force_error'] # use . 
[training.training_parameter.loss_parameter.weight]
-per_molecule_energy_error = 0.999
+per_molecule_energy_error = 0.999 #NOTE: reciprocal units
 per_atom_force_error = 0.001

 [training.early_stopping]
 verbose = true
diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index 9430d2f4..1b8328b6 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -909,9 +909,15 @@ def perform_training(
     import toml

     dataset_statistic = toml.load(dm.dataset_statistic_filename)
-    log.info(f"Setting E_i_mean and E_i_stddev for {model_name}")
-    log.info(f"E_i_mean: {dataset_statistic['atomic_energies_stats']['E_i_mean']}")
-    log.info(f"E_i_stddev: {dataset_statistic['atomic_energies_stats']['E_i_stddev']}")
+    log.info(
+        f"Setting per_atom_energy_mean and per_atom_energy_stddev for {model_name}"
+    )
+    log.info(
+        f"per_atom_energy_mean: {dataset_statistic['training_dataset_statistics']['per_atom_energy_mean']}"
+    )
+    log.info(
+        f"per_atom_energy_stddev: {dataset_statistic['training_dataset_statistics']['per_atom_energy_stddev']}"
+    )

     # Set up model
     model = NeuralNetworkPotentialFactory.generate_model(

From cab0ed98227658652168168f4ea17749e9a3775c Mon Sep 17 00:00:00 2001
From: wiederm
Date: Thu, 11 Jul 2024 20:10:53 +0200
Subject: [PATCH 48/78] fix test

---
 modelforge/potential/models.py | 35 ++++++------
 modelforge/tests/test_ani.py | 22 +++++---
 modelforge/tests/test_dataset.py | 8 ++-
 modelforge/tests/test_models.py | 92 +++++++++++++++++++++-----------
 modelforge/tests/test_sake.py | 14 +++--
 modelforge/tests/test_schnet.py | 7 +--
 6 files changed, 113 insertions(+), 65 deletions(-)

diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py
index 9c5841a8..20d51d79 100644
--- a/modelforge/potential/models.py
+++ b/modelforge/potential/models.py
@@ -816,23 +816,23 @@ def _initialize_postprocessing(
                 )
             )

-        # check if also self-energies are requested
-        if operations.get("calculate_atomic_self_energy", False):
-            if self.dataset_statistic is None:
-                log.warning(
-                    "Dataset statistics are required to calculate the molecular self-energies but haven't been provided."
-                )
-            else:
-                atomic_self_energies = self.dataset_statistic[
-                    "atomic_self_energies"
-                ]
+            # check if also self-energies are requested
+            elif operations.get("calculate_atomic_self_energy", False):
+                if self.dataset_statistic is None:
+                    log.warning(
+                        "Dataset statistics are required to calculate the atomic self-energies but haven't been provided."
+ ) + else: + atomic_self_energies = self.dataset_statistic[ + "atomic_self_energies" + ] - postprocessing_sequence.append( - CalculateAtomicSelfEnergy(atomic_self_energies)() - ) - prostprocessing_sequence_names.append( - "calculate_atomic_self_energy" - ) + postprocessing_sequence.append( + CalculateAtomicSelfEnergy(atomic_self_energies)() + ) + prostprocessing_sequence_names.append( + "calculate_atomic_self_energy" + ) log.debug(prostprocessing_sequence_names) @@ -845,7 +845,8 @@ def forward(self, data: Dict[str, torch.Tensor]): # NOTE: this is not very elegant, but I am unsure how to do this better # I am currently directly writing new keys and values in the data dictionary - for property in list(data.keys()): + property_keys = list(self.registered_chained_operations.keys()) + for property in property_keys: if property in self._registered_properties: self.registered_chained_operations[property](data) diff --git a/modelforge/tests/test_ani.py b/modelforge/tests/test_ani.py index d7f04ec5..3dddcf2f 100644 --- a/modelforge/tests/test_ani.py +++ b/modelforge/tests/test_ani.py @@ -104,15 +104,20 @@ def test_forward_and_backward(): import torch # read default parameters - config = load_configs("ani2x_without_ase", "qm9") - # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) + config = load_configs("ani2x", "qm9") _, _, _, mf_input = setup_two_methanes() device = torch.device("cpu") - model = ANI2x(**potential_parameter).to(device=device) + + # initialize model + model = ANI2x( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ).to(device=device) energy = model(mf_input) - derivative = torch.autograd.grad(energy["E"].sum(), mf_input.positions)[0] + derivative = torch.autograd.grad( + energy["per_molecule_energy"].sum(), mf_input.positions + )[0] per_atom_force = -derivative @@ -298,12 +303,15 @@ def test_compare_aev(): from modelforge.tests.test_models import load_configs # read default parameters - config = load_configs("ani2x_without_ase", "qm9") + config = load_configs("ani2x", "qm9") # Extract parameters potential_parameter = config["potential"].get("potential_parameter", {}) - mf_model = ANI2x(**potential_parameter) + mf_model = ANI2x( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) # perform input checks mf_model.input_preparation._input_checks(mf_input) # prepare the input for the forward pass diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index 67dc3e34..ff8bb2f7 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -770,12 +770,16 @@ def test_energy_postprocessing(): from openff.units import unit assert np.isclose( - unit.Quantity(dataset_statistic["atomic_energies_stats"]["E_i_mean"]).m, + unit.Quantity( + dataset_statistic["training_dataset_statistics"]["per_atom_energy_mean"] + ).m, -402.916561, ) assert np.isclose( - unit.Quantity(dataset_statistic["atomic_energies_stats"]["E_i_stddev"]).m, + unit.Quantity( + dataset_statistic["training_dataset_statistics"]["per_atom_energy_stddev"] + ).m, 25.013382078330697, ) diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index ca8c900d..0c18fd38 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -88,10 +88,13 @@ def test_model_factory(model_name, simulation_environment): def test_energy_scaling_and_offset(): # setup test 
dataset
-    from modelforge.dataset.dataset import DataModule
     from modelforge.potential.ani import ANI2x
+    from modelforge.dataset.dataset import DataModule
+
     import torch

+    # prepare reference value
+    # get methane input
     # test the self energy calculation on the QM9 dataset
     from modelforge.dataset.utils import FirstComeFirstServeSplittingStrategy

@@ -106,43 +109,68 @@ def test_energy_scaling_and_offset():
     )
     dataset.prepare_data()
     dataset.setup()
+    # get methane input
+    methane = next(iter(dataset.train_dataloader(shuffle=False))).nnp_input
+    # load dataset statistic
+    import toml
+
+    dataset_statistic = toml.load(dataset.dataset_statistic_filename)
     # -------------------------------#
-    # initialize model
+    # initialize model without any postprocessing
     # -------------------------------#
-    config = load_configs("ani2x_without_ase", "qm9")
+    config = load_configs("ani2x", "qm9")

-    import toml
+    torch.manual_seed(42)
+    model = ANI2x(
+        **config["potential"]["core_parameter"],
+        postprocessing_parameter=config["potential"]["postprocessing_parameter"],
+    )
+    output_no_postprocessing = model(methane)
+    # -------------------------------#
+    # Scale output

-    dataset_statistic = toml.load(dataset.dataset_statistic_filename)
     torch.manual_seed(42)
-    model = ANI2x(**potential_parameter, dataset_statistic=dataset_statistic)
+    model = ANI2x(
+        **config["potential"]["core_parameter"],
+        postprocessing_parameter=config["potential"]["postprocessing_parameter"],
+        dataset_statistic=dataset_statistic,
+    )
+    scaled_output = model(methane)

-    # -------------------------------#
-    # Test that we can add the reference energy correctly
-    # get methane input
-    methane = next(iter(dataset.train_dataloader())).nnp_input
+    # make sure that the scaled output matches the unscaled output after mean/stddev scaling
+    from openff.units import unit

-    # let's predict without any further postprocessing
-    output_no_postprocessing = model(methane)
+    mean = unit.Quantity(
+        dataset_statistic["training_dataset_statistics"]["per_atom_energy_mean"]
+    ).m
+    stddev = unit.Quantity(
+        dataset_statistic["training_dataset_statistics"]["per_atom_energy_stddev"]
+    ).m

-    # let's add self energies
-    import toml
+    compare_to = output_no_postprocessing["per_atom_energy"] * stddev + mean
+    assert torch.allclose(scaled_output["per_atom_energy"], compare_to)

-    # load dataset statistic
-    dataset_statistic = toml.load(dataset.dataset_statistic_filename)
-    # load potential parameter
-    config = load_configs("ani2x", "qm9")
-    potential_parameter = config["potential"].get("potential_parameter", {})
-    torch.manual_seed(42)
-    model = ANI2x(**potential_parameter, dataset_statistic=dataset_statistic)
-    output_with_ase = model(methane)
+    # -------------------------------#
+    # Calculate molecular self energies

-    # make sure that the raw prediction is the same
-    import torch
+    # modify postprocessing parameters
+    config["potential"]["postprocessing_parameter"][
+        "general_postprocessing_operation"
+    ] = {"calculate_molecular_self_energy": True}

-    assert torch.isclose(output_no_postprocessing["E"], output_with_ase["E"])
+    model = ANI2x(
+        **config["potential"]["core_parameter"],
+        postprocessing_parameter=config["potential"]["postprocessing_parameter"],
+        dataset_statistic=dataset_statistic,
+    )

-    assert torch.isclose(output_with_ase["mse"], torch.tensor([-707050.0]))
+    output_with_molecular_self_energies = model(methane)
+
+    # check the predicted molecular self energy against the expected reference value
+    assert torch.isclose(
+        output_with_molecular_self_energies["per_molecule_self_energy"],
+        torch.tensor([-104620.5859]),
+    )


 @pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names())
@@ -208,7 +236,7 @@ def test_dataset_statistic(model_name):

     # extract value to compare against
     toml_E_i_mean = unit.Quantity(
-        dataset_statistic["atomic_energies_stats"]["E_i_mean"]
+        dataset_statistic["training_dataset_statistics"]["per_atom_energy_mean"]
     ).m

     # set up training model
@@ -223,13 +251,13 @@ def test_dataset_statistic(model_name):
     import numpy as np

     print(training_adapter.model.postprocessing.dataset_statistic)
-    # check that the E_i_mean is the same than in the dataset statistics
+    # check that the per_atom_energy_mean is the same as in the dataset statistics
     assert np.isclose(
         toml_E_i_mean,
         unit.Quantity(
             training_adapter.model.postprocessing.dataset_statistic[
-                "atomic_energies_stats"
-            ]["E_i_mean"]
+                "training_dataset_statistics"
+            ]["per_atom_energy_mean"]
         ).m,
     )

@@ -250,7 +278,9 @@ def test_dataset_statistic(model_name):
     assert np.isclose(
         toml_E_i_mean,
         unit.Quantity(
-            model.postprocessing.dataset_statistic["atomic_energies_stats"]["E_i_mean"]
+            model.postprocessing.dataset_statistic["training_dataset_statistics"][
+                "per_atom_energy_mean"
+            ]
         ).m,
     )

diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py
index 94c2709a..60b2c2ff 100644
--- a/modelforge/tests/test_sake.py
+++ b/modelforge/tests/test_sake.py
@@ -587,11 +587,13 @@ def test_model_invariance(single_batch_with_batchsize_1):

     from modelforge.tests.test_models import load_configs

-    config = load_configs(f"sake_without_ase", "qm9")
-    # Extract parameters
-    potential_parameter = config["potential"].get("potential_parameter", {})
+    config = load_configs(f"sake", "qm9")

-    model = SAKE(**potential_parameter)
+    # initialize model
+    model = SAKE(
+        **config["potential"]["core_parameter"],
+        postprocessing_parameter=config["potential"]["postprocessing_parameter"],
+    )

     # get methane input
     methane = single_batch_with_batchsize_1.nnp_input
@@ -602,4 +604,6 @@ def test_model_invariance(single_batch_with_batchsize_1):

     reference_out = model(methane)
     perturbed_out = model(perturbed_methane_input)
-    assert torch.allclose(reference_out["E"], perturbed_out["E"])
+    assert torch.allclose(
+        reference_out["per_molecule_energy"], perturbed_out["per_molecule_energy"]
+    )
diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py
index 04353bf7..4267c533 100644
--- a/modelforge/tests/test_schnet.py
+++ b/modelforge/tests/test_schnet.py
@@ -165,13 +165,15 @@ def test_compare_forward():

     from modelforge.tests.test_models import load_configs

-    torch.manual_seed(1234)
     # load default parameters
     config = load_configs(f"schnet", "qm9")

     # override default parameters
     config["potential"]["core_parameter"]["number_of_atom_features"] = 12
     config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5
+    config["potential"]["core_parameter"]["number_of_filters"] = 12
+
+    torch.manual_seed(1234)

     # initialize model
     schnet = SchNet(
@@ -277,8 +279,7 @@ def test_compare_forward():
     # ---------------------------------------- #
     # test forward pass
     # ---------------------------------------- #
-
-    # Check full pass
+    # reset
     torch.manual_seed(1234)
     for i in range(3):
         schnet.core_module.interaction_modules[i].intput_to_feature.reset_parameters()

From 03a5930ba79205d8f0dbffb15377ac9499334b77 Mon Sep 17 00:00:00 2001
From: wiederm
Date: Thu, 11 Jul 2024 23:06:58 +0200
Subject: [PATCH 49/78] error calculation tests

---
 modelforge/tests/conftest.py                  |   8 +
 .../tests/data/training_defaults/default.toml |   6 +-
modelforge/tests/test_dataset.py | 10 +-
 modelforge/tests/test_training.py | 101 ++++------
 modelforge/tests/test_utils.py | 15 +-
 modelforge/train/training.py | 175 ++++++++++--------
 6 files changed, 160 insertions(+), 155 deletions(-)

diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py
index 7db41d0f..b42b03b3 100644
--- a/modelforge/tests/conftest.py
+++ b/modelforge/tests/conftest.py
@@ -122,6 +122,14 @@ def single_batch_with_batchsize_2_with_force():
     return single_batch(batch_size=2, dataset_name="PHALKETHOH")


+@pytest.fixture(scope="session")
+def single_batch_with_batchsize_16_with_force():
+    """
+    Utility fixture to create a single batch of data for testing.
+    """
+    return single_batch(batch_size=16, dataset_name="PHALKETHOH")
+
+
 def initialize_dataset(
     dataset_name: str,
     local_cache_dir: str,
diff --git a/modelforge/tests/data/training_defaults/default.toml b/modelforge/tests/data/training_defaults/default.toml
index 7dbaa5dd..87e14eaa 100644
--- a/modelforge/tests/data/training_defaults/default.toml
+++ b/modelforge/tests/data/training_defaults/default.toml
@@ -24,10 +24,10 @@ monitor = "val/per_molecule_energy/rmse"
 interval = "epoch"

 [training.training_parameter.loss_parameter]
-loss_property = ['per_molecule_energy_error', 'per_atom_force_error'] # use .
+loss_property = ['per_molecule_energy', 'per_atom_force'] # use .
 [training.training_parameter.loss_parameter.weight]
-per_molecule_energy_error = 0.999 #NOTE: reciprocal units
-per_atom_force_error = 0.001
+per_molecule_energy = 0.999 #NOTE: reciprocal units
+per_atom_force = 0.001

 [training.early_stopping]
 verbose = true
diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py
index ff8bb2f7..d5cdb0b7 100644
--- a/modelforge/tests/test_dataset.py
+++ b/modelforge/tests/test_dataset.py
@@ -464,22 +464,18 @@ def test_dataset_neighborlist(model_name, single_batch_with_batchsize_64):
     nnp_input = single_batch_with_batchsize_64.nnp_input

     # test that the neighborlist is correctly generated
-    # cast input and model to torch.float64
-    # read default parameters
     from modelforge.tests.test_models import load_configs

     # read default parameters
-    config = load_configs(f"{model_name}_without_ase", "qm9")
+    config = load_configs(f"{model_name}", "qm9")

     # Extract parameters
-    potential_parameter = config["potential"].get("potential_parameter", {})
     from modelforge.potential.models import NeuralNetworkPotentialFactory
-
+    # initialize model
     model = NeuralNetworkPotentialFactory.generate_model(
         use="inference",
-        model_type=model_name,
         simulation_environment="PyTorch",
-        model_parameter=potential_parameter,
+        model_parameter=config["potential"],
     )

     model(nnp_input)
diff --git a/modelforge/tests/test_training.py b/modelforge/tests/test_training.py
index 04018d33..bb6be277 100644
--- a/modelforge/tests/test_training.py
+++ b/modelforge/tests/test_training.py
@@ -35,19 +35,7 @@ def load_configs(model_name: str, dataset_name: str):
 @pytest.mark.skipif(ON_MACOS, reason="Skipping this test on MacOS GitHub Actions")
 @pytest.mark.parametrize("model_name", _Implemented_NNPs.get_all_neural_network_names())
 @pytest.mark.parametrize("dataset_name", ["QM9"])
-@pytest.mark.parametrize(
-    "loss_type",
-    [
-        {
-            "loss_type": "EnergyAndForceLoss",
-            "include_force": True,
-            "force_weight": 0.99,
-            "energy_weight": 0.01,
-        },
-        {"loss_type": "EnergyAndForceLoss"},
-    ],
-)
-def test_train_with_lightning(model_name, dataset_name, loss_type):
+def test_train_with_lightning(model_name, dataset_name):
     """
     Test the forward
pass for a given model and dataset. """ @@ -63,8 +51,6 @@ def test_train_with_lightning(model_name, dataset_name, loss_type): dataset_config = config["dataset"] runtime_config = config["runtime"] - # set loss type - training_config["training_parameter"]["loss_parameter"] = loss_type # perform training trainer = perform_training( potential_config=potential_config, @@ -87,59 +73,46 @@ def test_train_with_lightning(model_name, dataset_name, loss_type): import torch -def test_loss_fkt(single_batch_with_batchsize_2_with_force): - from torch_scatter import scatter_sum - - batch = single_batch_with_batchsize_2_with_force - E_true = batch.metadata.E - F_true = batch.metadata.F - F_predict = torch.randn_like(F_true) - E_predict = torch.randn_like(E_true) - - F_scaling = torch.tensor([1.0]) - - F_error_per_atom = torch.norm(F_true - F_predict, dim=1) ** 2 - F_error_per_molecule = scatter_sum( - F_error_per_atom, batch.nnp_input.atomic_subsystem_indices.long(), 0 +def test_error_calculation(single_batch_with_batchsize_16_with_force): + # test the different Loss classes + from modelforge.train.training import ( + FromPerAtomToPerMoleculeError, + PerMoleculeError, ) + from torch_scatter import scatter_sum - scale = F_scaling / (3 * batch.metadata.atomic_subsystem_counts) - F_per_mol_scaled = F_error_per_molecule / scale - - -@pytest.fixture -def _initialize_predict_target_dictionary(): - # initalize the test system - predict_target = {} - predict_target["E_predict"] = torch.tensor([[1.0], [2.0], [3.0]]) - predict_target["E_true"] = torch.tensor([[1.0], [-2.0], [3.0]]) - predict_target["F_predict"] = torch.tensor( - [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]] + # generate data + data = single_batch_with_batchsize_16_with_force + true_E = data.metadata.E + true_F = data.metadata.F + + # make predictions + predicted_E = true_E + torch.rand_like(true_E) * 10 + predicted_F = true_F + torch.rand_like(true_F) * 10 + # test energy error + error = PerMoleculeError() + E_error = error(predicted_E, true_E, data) + + # test energy error (mean squared error scaled by number of atoms in the molecule) + reference_E_error = torch.mean( + ((predicted_E - true_E) ** 2) / data.metadata.atomic_subsystem_counts ) - predict_target["F_true"] = torch.tensor( - [[1.0, -2.0, -3.0], [1.0, -2.0, -3.0], [1.0, -2.0, -3.0]] + assert torch.allclose(E_error, reference_E_error) + + # test force error + error = FromPerAtomToPerMoleculeError() + F_error = error(predicted_F, true_F, data) + + # test force error (mean squared error scaled by number of atoms in the molecule) + reference_F_error = torch.mean( + scatter_sum( + torch.norm(predicted_F - true_F, dim=1) ** 2, + data.nnp_input.atomic_subsystem_indices.long(), + 0, + ) + / data.metadata.atomic_subsystem_counts ) - return predict_target - - -def test_energy_loss_only(_initialize_predict_target_dictionary): - # test the different Loss classes - from modelforge.train.training import EnergyLoss - - # initialize loss - loss_calculator = EnergyLoss() - predict_target = _initialize_predict_target_dictionary - # this loss calculates validation and training error as MSE and test error as RMSE - mse_expected_loss = torch.mean( - (predict_target["E_predict"] - predict_target["E_true"]) ** 2 - ) - - # test loss class - # make sure that train loss is MSE as expected - loss = loss_calculator.calculate_loss(predict_target, None) - assert torch.isclose( - mse_expected_loss, loss["combined_loss"] - ), f"Expected {mse_expected_loss.item()} but got {loss['combined_loss'].item()}" + assert 
torch.allclose(F_error, reference_F_error)


 @pytest.mark.skipif(
diff --git a/modelforge/tests/test_utils.py b/modelforge/tests/test_utils.py
index 393c4a03..2bb80bd0 100644
--- a/modelforge/tests/test_utils.py
+++ b/modelforge/tests/test_utils.py
@@ -438,11 +438,16 @@ def test_energy_readout():

     # the input for the EnergyReadout module is vector (E_i) that will be scatter_added, and
     # a second tensor supplying the indixes for the summation
-    E_i = torch.tensor([3, 3, 1, 1, 1, 1, 1, 1], dtype=torch.float32)
-    atomic_subsystem_indices = torch.tensor([0, 0, 1, 1, 1, 1, 1, 1])
-
-    energy_readout = FromAtomToMoleculeReduction()
-    E = energy_readout(E_i, atomic_subsystem_indices)
+    r = {
+        "per_atom_energy": torch.tensor([3, 3, 1, 1, 1, 1, 1, 1], dtype=torch.float32),
+        "atomic_subsystem_index": torch.tensor([0, 0, 1, 1, 1, 1, 1, 1]),
+    }
+    energy_readout = FromAtomToMoleculeReduction(
+        per_atom_property_name="per_atom_energy",
+        index_name="atomic_subsystem_index",
+        output_name="per_molecule_energy",
+    )
+    E = energy_readout(r)["per_molecule_energy"]

     # check that output has length of total number of molecules in batch
     assert E.size() == torch.Size(
diff --git a/modelforge/train/training.py b/modelforge/train/training.py
index 1b8328b6..0935bd2d 100644
--- a/modelforge/train/training.py
+++ b/modelforge/train/training.py
@@ -57,11 +57,11 @@ def forward(
         per_molecule_squared_error = scatter_sum(
             per_atom_squared_error, batch.nnp_input.atomic_subsystem_indices.long(), 0
         )
+        # divide by number of atoms
         per_molecule_square_error_scaled = (
             per_molecule_squared_error / batch.metadata.atomic_subsystem_counts
         )
-        # divide by nnumber of atoms
-        return per_molecule_square_error_scaled
+        return torch.mean(per_molecule_square_error_scaled)


 class PerMoleculeError(nn.Module):
@@ -104,7 +104,8 @@ def forward(
         )

         # average
-        return torch.mean(per_molecule_square_error_scaled)
+        per_molecule_average = torch.mean(per_molecule_square_error_scaled)
+        return per_molecule_average


 class Loss(nn.Module):
@@ -121,7 +122,7 @@ class Loss(nn.Module):
         Module dictionary containing the loss functions for each property.
     """

-    _SUPPORTED_PROPERTIES = ["per_molecule_energy_error", "per_atom_force_error"]
+    _SUPPORTED_PROPERTIES = ["per_molecule_energy", "per_atom_force"]

     def __init__(self, loss_porperty: List[str], weight: Dict[str, float]):
         """
@@ -140,7 +141,6 @@ def __init__(self, loss_porperty: List[str], weight: Dict[str, float]):
             If an unsupported loss type is specified.
         """
         super().__init__()
-        from torch.nn import ModuleDict

         self.loss_property = loss_porperty
@@ -150,7 +150,7 @@ def __init__(self, loss_porperty: List[str], weight: Dict[str, float]):

         for prop, w in weight.items():
             if prop in self._SUPPORTED_PROPERTIES:
-                if prop == "per_atom_force_error":
+                if prop == "per_atom_force":
                     self.loss[prop] = FromPerAtomToPerMoleculeError()
                 else:
                     self.loss[prop] = PerMoleculeError()
@@ -174,15 +174,24 @@ def forward(self, predict_target: Dict[str, torch.Tensor], batch):
         torch.Tensor
             The combined loss for the specified properties.
""" - - loss = torch.zeros_like(predict_target["E_true"]) - + loss = torch.tensor([0.0]) + # save the loss as a dictionary + r = {} + # iterate over loss properties for prop in self.loss_property: - loss += self.weight[prop] * self.loss[prop]( - predict_target[prop], predict_target[f"{prop}_true"], batch + # calculate loss per property + loss_ = self.weight[prop] * self.loss[prop]( + predict_target[f"{prop}_predict"], predict_target[f"{prop}_true"], batch ) + # add total loss + loss = loss + loss_ + # save loss + r[prop] = loss_ + + # add total loss to results dict and return + r["total_loss"] = loss - return loss + return r class LossFactory(object): @@ -201,7 +210,6 @@ def create_loss(loss_property: List[str], weight: Dict[str, float]) -> Type[Loss List of properties to include in the loss calculation. weight : Dict[str, float] Dictionary containing the weights for each property in the loss calculation. - Returns ------- Loss @@ -246,60 +254,66 @@ def __init__( """ from modelforge.potential import _Implemented_NNPs - from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError - from torchmetrics import MetricCollection super().__init__() self.save_hyperparameters() - # Extracting and instantiating the model from parameters - model_name = model_parameter["model_name"] # Get requested model class + model_name = model_parameter["model_name"] nnp_class: Type = _Implemented_NNPs.get_neural_network_class(model_name) + # initialize model self.model = nnp_class( **model_parameter["core_parameter"], dataset_statistic=dataset_statistic, postprocessing_parameter=model_parameter["postprocessing_parameter"], ) + self.optimizer = optimizer self.learning_rate = lr self.lr_scheduler_config = lr_scheduler_config + + # register metrics + self._register_metrics(loss_parameter) + + # initialize loss self.loss = LossFactory.create_loss(**loss_parameter) - self.val_error = { - "energy": MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ), - "force": MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ), - } - self.train_error = { - "energy": MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ), - "force": MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ), - } - self.test_error = { - "energy": MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ), - "force": MetricCollection( - [MeanAbsoluteError(), MeanSquaredError(squared=False)] - ), - } + def _register_metrics(self, loss_parameter: Dict[str, Any]): + from torchmetrics.regression import ( + MeanAbsoluteError, + MeanSquaredError, + ) + from torchmetrics import MetricCollection - # Register metrics - for phase, metrics in [ - ("val", self.val_error), - ("train", self.train_error), - ("test", self.test_error), - ]: - for property, collection in metrics.items(): - self.add_module(f"{phase}_{property}", collection) + # register logging + from torch.nn import ModuleDict + + self.test_error = ModuleDict( + { + prop: MetricCollection( + [MeanAbsoluteError(), MeanSquaredError(squared=False)] + ) + for prop in loss_parameter["loss_property"] + } + ) + self.val_error = ModuleDict( + { + prop: MetricCollection( + [MeanAbsoluteError(), MeanSquaredError(squared=False)] + ) + for prop in loss_parameter["loss_property"] + } + ) + + self.train_error = ModuleDict( + { + prop: MetricCollection( + [MeanAbsoluteError(), MeanSquaredError(squared=False)] + ) + for prop in loss_parameter["loss_property"] + } + ) def _get_forces( self, 
batch: "BatchData", energies: Dict[str, torch.Tensor] @@ -320,29 +334,32 @@ def _get_forces( The true forces from the dataset and the predicted forces by the model. """ nnp_input = batch.nnp_input - F_true = batch.metadata.F.to(torch.float32) + per_atom_force_true = batch.metadata.F.to(torch.float32) - if F_true.numel() < 1: + if per_atom_force_true.numel() < 1: raise RuntimeError("No force can be calculated.") - E_predict = energies["E_predict"] + per_molecule_energy_predict = energies["per_molecule_energy_predict"] # Ensure E_predict and nnp_input.positions require gradients and are on the same device - if not E_predict.requires_grad: - E_predict.requires_grad = True + if not per_molecule_energy_predict.requires_grad: + per_molecule_energy_predict.requires_grad = True if not nnp_input.positions.requires_grad: nnp_input.positions.requires_grad = True # Compute the gradient (forces) from the predicted energies grad = torch.autograd.grad( - E_predict.sum(), + per_molecule_energy_predict.sum(), nnp_input.positions, create_graph=False, retain_graph=True, )[0] - F_predict = -1 * grad # Forces are the negative gradient of energy + per_atom_force_predict = -1 * grad # Forces are the negative gradient of energy - return {"F_true": F_true, "F_predict": F_predict} + return { + "per_atom_force_true": per_atom_force_true, + "per_atom_force_predict": per_atom_force_predict, + } def _get_energies(self, batch: "BatchData") -> Dict[str, torch.Tensor]: """ @@ -359,13 +376,18 @@ def _get_energies(self, batch: "BatchData") -> Dict[str, torch.Tensor]: The true energies from the dataset and the predicted energies by the model. """ nnp_input = batch.nnp_input - E_true = batch.metadata.E.to(torch.float32).squeeze(1) - E_predict = self.model.forward(nnp_input)["E"] - assert E_true.shape == E_predict.shape, ( + per_molecule_energy_true = batch.metadata.E.to(torch.float32).squeeze(1) + per_molecule_energy_predict = self.model.forward(nnp_input)[ + "per_molecule_energy" + ] + assert per_molecule_energy_true.shape == per_molecule_energy_predict.shape, ( f"Shapes of true and predicted energies do not match: " - f"{E_true.shape} != {E_predict.shape}" + f"{per_molecule_energy_true.shape} != {per_molecule_energy_predict.shape}" ) - return {"E_true": E_true, "E_predict": E_predict} + return { + "per_molecule_energy_true": per_molecule_energy_true, + "per_molecule_energy_predict": per_molecule_energy_predict, + } def _get_predictions(self, batch: "BatchData") -> Dict[str, torch.Tensor]: """ @@ -399,7 +421,7 @@ def _log_metrics( self, error_dict: Dict[str, torchmetrics.MetricCollection], predict_target: Dict[str, torch.Tensor], - ) -> Dict[str, torch.Tensor]: + ): """ Updates the provided metric collections with the predicted and true targets. 
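The force path above is plain reverse-mode autodiff: the predicted per-molecule energies are summed and differentiated with respect to the atomic positions, and the forces are the negative gradient. A minimal, self-contained sketch of the same pattern (the quadratic energy below is a stand-in for a real potential, not modelforge API):

    import torch

    positions = torch.randn(5, 3, requires_grad=True)
    # stand-in for model(nnp_input)["per_molecule_energy"].sum()
    energy = (positions**2).sum()

    grad = torch.autograd.grad(energy, positions, retain_graph=True)[0]
    per_atom_force = -grad  # forces are the negative gradient of the energy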
@@ -418,15 +440,15 @@ def _log_metrics( for property, metrics in error_dict.items(): for metric, error_log in metrics.items(): - if property == "energy": + if property == "per_molecule_energy": error_log( - predict_target["E_predict"].detach(), - predict_target["E_true"].detach(), + predict_target["per_molecule_energy_predict"].detach(), + predict_target["per_molecule_energy_true"].detach(), ) if property == "per_atom_force": error_log( - predict_target["F_predict"].detach(), - predict_target["F_true"].detach(), + predict_target["per_atom_force_predict"].detach(), + predict_target["per_atom_force_true"].detach(), ) def training_step(self, batch: "BatchData", batch_idx: int) -> torch.Tensor: @@ -450,10 +472,12 @@ def training_step(self, batch: "BatchData", batch_idx: int) -> torch.Tensor: predict_target = self._get_predictions(batch) # calculate the loss - loss_dict = self.loss_module(predict_target, batch) - # Update and log metrics + loss_dict = self.loss(predict_target, batch) + + # Update and log training error self._log_metrics(self.train_error, predict_target) - # log loss + + # log the loss for key, loss in loss_dict.items(): self.log( f"train/{key}", @@ -464,7 +488,7 @@ def training_step(self, batch: "BatchData", batch_idx: int) -> torch.Tensor: batch_size=1, ) # batch size is 1 because the mean of the batch is logged - return loss_dict["combined_loss"] + return loss_dict["total_loss"] @torch.enable_grad() def validation_step(self, batch: "BatchData", batch_idx: int) -> None: @@ -488,7 +512,7 @@ def validation_step(self, batch: "BatchData", batch_idx: int) -> None: # calculate energy and forces predict_target = self._get_predictions(batch) # calculate the loss - loss = self.loss_module.calculate_loss(predict_target, batch) + loss = self.loss(predict_target, batch) # log the loss self._log_metrics(self.val_error, predict_target) @@ -922,9 +946,8 @@ def perform_training( # Set up model model = NeuralNetworkPotentialFactory.generate_model( use="training", - model_type=model_name, dataset_statistic=dataset_statistic, - model_parameter=potential_config["potential_parameter"], + model_parameter=potential_config, training_parameter=training_config["training_parameter"], ) @@ -947,7 +970,7 @@ def perform_training( checkpoint_callback = ModelCheckpoint( save_top_k=2, - monitor="val/energy/rmse", + monitor="val/per_molecule_energy/rmse", filename="best_{potential_name}-{dataset_name}-{epoch:02d}-{val_loss:.2f}", ) From 90d4b664fecf79809bce3c61b5470ce9a22fdfed Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 23:19:51 +0200 Subject: [PATCH 50/78] update --- devtools/conda-envs/doc_env.yaml | 3 +-- devtools/conda-envs/test_env.yaml | 3 --- devtools/conda-envs/test_env_mac.yaml | 7 +------ modelforge/potential/models.py | 2 +- 4 files changed, 3 insertions(+), 12 deletions(-) diff --git a/devtools/conda-envs/doc_env.yaml b/devtools/conda-envs/doc_env.yaml index 660b005f..ab8ab75a 100644 --- a/devtools/conda-envs/doc_env.yaml +++ b/devtools/conda-envs/doc_env.yaml @@ -37,6 +37,5 @@ dependencies: - pip: - jax - pytorch2jax - - versioneer - - flax - git+https://github.com/ArnNag/sake.git@nanometer + - "ray[data,train,tune,serve]" diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 213c6f16..ab8ab75a 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -37,8 +37,5 @@ dependencies: - pip: - jax - pytorch2jax - - versioneer - - flax - git+https://github.com/ArnNag/sake.git@nanometer - # - tensorflow - 
"ray[data,train,tune,serve]" diff --git a/devtools/conda-envs/test_env_mac.yaml b/devtools/conda-envs/test_env_mac.yaml index 8d5edb3c..079a1efc 100644 --- a/devtools/conda-envs/test_env_mac.yaml +++ b/devtools/conda-envs/test_env_mac.yaml @@ -36,11 +36,6 @@ dependencies: # pip installs - pip: - # - schnetpack>=2.0.0 - pytorch2jax - - versioneer - - flax - - torchviz - git+https://github.com/ArnNag/sake.git@nanometer - - tensorflow - - torchviz + - jax diff --git a/modelforge/potential/models.py b/modelforge/potential/models.py index 9c586f0a..62576d7a 100644 --- a/modelforge/potential/models.py +++ b/modelforge/potential/models.py @@ -895,7 +895,7 @@ def load_state_dict( # Prefix to remove prefix = "model." - excluded_keys = ["loss.per_molecule_energy_error", "loss.per_atom_force_error"] + excluded_keys = ["loss.per_molecule_energy", "loss.per_atom_force"] # Create a new dictionary without the prefix in the keys if prefix exists if any(key.startswith(prefix) for key in state_dict.keys()): From 62311c07f6f76cc6f9a13ddc49d3e85a50a4cd7f Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 23:26:10 +0200 Subject: [PATCH 51/78] add flax --- devtools/conda-envs/doc_env.yaml | 1 + devtools/conda-envs/test_env.yaml | 1 + devtools/conda-envs/test_env_mac.yaml | 1 + 3 files changed, 3 insertions(+) diff --git a/devtools/conda-envs/doc_env.yaml b/devtools/conda-envs/doc_env.yaml index ab8ab75a..20479aa1 100644 --- a/devtools/conda-envs/doc_env.yaml +++ b/devtools/conda-envs/doc_env.yaml @@ -36,6 +36,7 @@ dependencies: - pip: - jax + - flax - pytorch2jax - git+https://github.com/ArnNag/sake.git@nanometer - "ray[data,train,tune,serve]" diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index ab8ab75a..cce74a85 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -39,3 +39,4 @@ dependencies: - pytorch2jax - git+https://github.com/ArnNag/sake.git@nanometer - "ray[data,train,tune,serve]" + - flax diff --git a/devtools/conda-envs/test_env_mac.yaml b/devtools/conda-envs/test_env_mac.yaml index 079a1efc..d347993a 100644 --- a/devtools/conda-envs/test_env_mac.yaml +++ b/devtools/conda-envs/test_env_mac.yaml @@ -39,3 +39,4 @@ dependencies: - pytorch2jax - git+https://github.com/ArnNag/sake.git@nanometer - jax + - flax From 909238a70d2cc424c58a0cb9b93b9d214dae1f93 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 23:34:00 +0200 Subject: [PATCH 52/78] update yaml --- devtools/conda-envs/test_env.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index cce74a85..6fff42d5 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -29,7 +29,6 @@ dependencies: - pytest-cov - codecov - requests - - versioneer # Docs - sphinx_rtd_theme @@ -40,3 +39,4 @@ dependencies: - git+https://github.com/ArnNag/sake.git@nanometer - "ray[data,train,tune,serve]" - flax + - versioneer From 629df009773afd1832b35cbd12c5b88aa94e8304 Mon Sep 17 00:00:00 2001 From: wiederm Date: Thu, 11 Jul 2024 23:42:05 +0200 Subject: [PATCH 53/78] update --- devtools/conda-envs/test_env.yaml | 2 +- modelforge/tests/test_sake.py | 28 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 6fff42d5..cce74a85 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -29,6 +29,7 @@ dependencies: - 
pytest-cov - codecov - requests + - versioneer # Docs - sphinx_rtd_theme @@ -39,4 +40,3 @@ dependencies: - git+https://github.com/ArnNag/sake.git@nanometer - "ray[data,train,tune,serve]" - flax - - versioneer diff --git a/modelforge/tests/test_sake.py b/modelforge/tests/test_sake.py index 60b2c2ff..a1c4c902 100644 --- a/modelforge/tests/test_sake.py +++ b/modelforge/tests/test_sake.py @@ -102,11 +102,14 @@ def test_layer_equivariance(h_atol, eq_atol, single_batch_with_batchsize_64): from modelforge.tests.test_models import load_configs - config = load_configs(f"sake_without_ase", "qm9") + config = load_configs(f"sake", "qm9") # Extract parameters - potential_parameter = config["potential"].get("potential_parameter", {}) - potential_parameter["number_of_atom_features"] = nr_atom_basis - sake = SAKE(**potential_parameter) + core_parameter = config["potential"]["core_parameter"] + core_parameter["number_of_atom_features"] = nr_atom_basis + sake = SAKE( + **core_parameter, + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) # get methane input methane = single_batch_with_batchsize_64.nnp_input @@ -429,16 +432,13 @@ def test_model_against_reference(single_batch_with_batchsize_1): cutoff=cutoff, number_of_radial_basis_functions=50, epsilon=1e-8, - processing_operation=[], - readout_operation=[ - { - "step": "from_atom_to_molecule", - "mode": "sum", - "in": "per_atom_energy", - "index_key": "atomic_subsystem_indices", - "out": "E", + postprocessing_parameter={ + "per_atom_energy": { + "normalize": True, + "from_atom_to_molecule_reduction": True, + "keep_per_atom_property": True, } - ], + }, ) ref_sake = reference_sake.models.DenseSAKEModel( @@ -577,7 +577,7 @@ def test_model_against_reference(single_batch_with_batchsize_1): ref_out = ref_sake.apply(variables, h, x, mask=mask)[0].sum(-2) # ref_out is nan, so we can't compare it to the modelforge output - print(f"{mf_out['E']=}") + print(f"{mf_out['per_molecule_energy']=}") print(f"{ref_out=}") # assert torch.allclose(mf_out.E, torch.from_numpy(onp.array(ref_out[0]))) From 5cf4314567b2b591268648e059e208f566ead6ee Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 15:35:21 -0700 Subject: [PATCH 54/78] Add comments for todos --- modelforge/potential/spookynet.py | 55 ++++++++++++++---------------- modelforge/potential/utils.py | 2 +- modelforge/tests/test_spookynet.py | 3 +- 3 files changed, 29 insertions(+), 31 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 0f084b99..6efbc68e 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -204,7 +204,7 @@ def _forward(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: data.pair_indices, representation["f_ij"], representation["f_cutoff"], - representation["p_orbital_ij"], + representation["dir_ij"], representation["d_orbital_ij"] ) f += y # accumulate module output to features @@ -230,7 +230,6 @@ def __init__( number_of_interaction_modules: int = 3, cutoff: unit.Quantity = 5 * unit.angstrom, number_of_filters: int = 32, - shared_interactions: bool = False, ) -> None: """ Initialize the SpookyNet network. 
@@ -256,7 +255,6 @@ def __init__( number_of_radial_basis_functions=number_of_radial_basis_functions, number_of_interaction_modules=number_of_interaction_modules, number_of_filters=number_of_filters, - shared_interactions=shared_interactions, ) self.only_unique_pairs = False # NOTE: for pairlist self.input_preparation = InputPreparation( @@ -303,18 +301,17 @@ def __init__( super().__init__() # cutoff - from modelforge.potential import CosineCutoff - - # radial symmetry function - from .utils import PhysNetRadialSymmetryFunction + from .utils import ExponentialBernsteinRadialBasisFunction, CosineCutoff - self.radial_symmetry_function_module = PhysNetRadialSymmetryFunction( + self.radial_symmetry_function_module = ExponentialBernsteinRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, - max_distance=cutoff, + ini_alpha=1.0, # TODO: put the right number dtype=torch.float32, ) + self.cutoff_module = CosineCutoff(cutoff=cutoff) + def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Tensor]: """ Forward pass of the representation module. @@ -336,15 +333,15 @@ def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Ten sqrt3 = math.sqrt(3) sqrt3half = 0.5 * sqrt3 # short-range distances - p_orbital_ij = r_ij / d_ij.unsqueeze(-1) + dir_ij = r_ij / d_ij.unsqueeze(-1) d_orbital_ij = torch.stack( [ - sqrt3 * p_orbital_ij[:, 0] * p_orbital_ij[:, 1], # xy - sqrt3 * p_orbital_ij[:, 0] * p_orbital_ij[:, 2], # xz - sqrt3 * p_orbital_ij[:, 1] * p_orbital_ij[:, 2], # yz - 0.5 * (3 * p_orbital_ij[:, 2] * p_orbital_ij[:, 2] - 1.0), # z2 + sqrt3 * dir_ij[:, 0] * dir_ij[:, 1], # xy + sqrt3 * dir_ij[:, 0] * dir_ij[:, 2], # xz + sqrt3 * dir_ij[:, 1] * dir_ij[:, 2], # yz + 0.5 * (3 * dir_ij[:, 2] * dir_ij[:, 2] - 1.0), # z2 sqrt3half - * (p_orbital_ij[:, 0] * p_orbital_ij[:, 0] - p_orbital_ij[:, 1] * p_orbital_ij[:, 1]), # x2-y2 + * (dir_ij[:, 0] * dir_ij[:, 0] - dir_ij[:, 1] * dir_ij[:, 1]), # x2-y2 ], dim=-1, ) @@ -352,7 +349,7 @@ def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Ten f_ij_cutoff = self.cutoff_module(d_ij) filters = f_ij * f_ij_cutoff - return {"filters": filters, "p_orbital_ij": p_orbital_ij, "d_orbital_ij": d_orbital_ij} + return {"filters": filters, "dir_ij": dir_ij, "d_orbital_ij": d_orbital_ij} class Swish(nn.Module): @@ -590,7 +587,7 @@ def forward( self, x_tilde: torch.Tensor, f_ij_after_cutoff: torch.Tensor, - p_orbital_ij: torch.Tensor, + dir_ij: torch.Tensor, d_orbital_ij: torch.Tensor, idx_i: torch.Tensor, idx_j: torch.Tensor, @@ -604,7 +601,7 @@ def forward( Atomic feature vectors. rbf (FloatTensor [N, num_basis_functions]): Values of the radial basis functions for the pairwise distances. 
- p_orbital_ij (TODO): + dir_ij (TODO): TODO d_orbital_ij (TODO): TODO @@ -616,8 +613,8 @@ def forward( """ # interaction functions gs = self.radial_s(f_ij_after_cutoff) - gp = self.radial_p(f_ij_after_cutoff).unsqueeze(-2) * p_orbital_ij.unsqueeze(-1) - gd = self.radial_d(f_ij_after_cutoff).unsqueeze(-2) * d_orbital_ij.unsqueeze(-1) + gp = self.radial_p(f_ij_after_cutoff).unsqueeze(-2) * dir_ij.unsqueeze(-1) # TODO: replace with einsum + gd = self.radial_d(f_ij_after_cutoff).unsqueeze(-2) * d_orbital_ij.unsqueeze(-1) # TODO: replace with einsum # atom featurizations xx = self.resblock_x(x_tilde) xs = self.resblock_s(x_tilde) @@ -628,15 +625,15 @@ def forward( xp = xp[idx_j] # L=1 xd = xd[idx_j] # L=2 # sum over neighbors - pp = x_tilde.new_zeros(x_tilde.shape[0], p_orbital_ij.shape[-1], x_tilde.shape[-1]) + pp = x_tilde.new_zeros(x_tilde.shape[0], dir_ij.shape[-1], x_tilde.shape[-1]) dd = x_tilde.new_zeros(x_tilde.shape[0], d_orbital_ij.shape[-1], x_tilde.shape[-1]) - s = xx.index_add(0, idx_i, gs * xs) # L=0 - p = pp.index_add_(0, idx_i, gp * xp.unsqueeze(-2)) # L=1 - d = dd.index_add_(0, idx_i, gd * xd.unsqueeze(-2)) # L=2 + s = xx.index_add(0, idx_i, gs * xs) # L=0 # TODO: replace with einsum + p = pp.index_add_(0, idx_i, gp * xp.unsqueeze(-2)) # L=1 # TODO: replace with einsum + d = dd.index_add_(0, idx_i, gd * xd.unsqueeze(-2)) # L=2 # TODO: replace with einsum # project tensorial features to scalars pa, pb = torch.split(self.projection_p(p), p.shape[-1], dim=-1) da, db = torch.split(self.projection_d(d), d.shape[-1], dim=-1) - return self.resblock(s + (pa * pb).sum(-2) + (da * db).sum(-2)) + return self.resblock(s + (pa * pb).sum(-2) + (da * db).sum(-2)) # TODO: replace with einsum class SpookyNetAttention(nn.Module): @@ -878,8 +875,8 @@ def forward( pairlist: torch.Tensor, # shape [n_pairs, 2] f_ij: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] TODO: why the 1? f_ij_cutoff: torch.Tensor, # shape [n_pairs, 1] - p_orbital_ij: torch.Tensor, # shape [n_pairs, 1] - d_orbital_ij: torch.Tensor, # shape [n_pairs, 1] + dir_ij: torch.Tensor, # shape [n_pairs, 1] + d_orbital_ij: torch.Tensor, # shape [n_pairs, 1] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Evaluate all modules in the block. @@ -892,7 +889,7 @@ def forward( Latent atomic feature vectors. rbf (FloatTensor [P, num_basis_functions]): Values of the radial basis functions for the pairwise distances. - p_orbital_ij (FloatTensor [P, 3]): + dir_ij (FloatTensor [P, 3]): Unit vectors pointing from atom i to atom j for all atomic pairs. d_orbital_ij (FloatTensor [P]): Distances between atom i and atom j for all atomic pairs. @@ -911,7 +908,7 @@ def forward( idx_i, idx_j = pairlist[0], pairlist[1] x_tilde = self.residual_pre(x) del x - l = self.local_interaction(x_tilde, f_ij * f_ij_cutoff, p_orbital_ij, d_orbital_ij, idx_i, idx_j) + l = self.local_interaction(x_tilde, f_ij * f_ij_cutoff, dir_ij, d_orbital_ij, idx_i, idx_j) n = self.nonlocal_interaction(x_tilde) x_updated = self.residual_post(x_tilde + l + n) del x_tilde diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index ff784d73..6b95e2f6 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -984,7 +984,7 @@ def __init__(self, number_of_radial_basis_functions, ini_alpha, dtype=torch.int6 trainable_prefactor=False, dtype=dtype, ) - self.alpha = ini_alpha + self.alpha = ini_alpha #TODO: should this be unitful? 
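    # Sketch of the underlying math (following the SpookyNet construction; the
    # implementation details here are assumptions): the basis evaluates Bernstein
    # polynomials at x = exp(-alpha * d_ij),
    #     g_v(d) = binom(K - 1, v) * x**v * (1 - x)**(K - 1 - v),  v = 0, ..., K - 1,
    # with K = number_of_radial_basis_functions; the nondimensionalization below
    # presumably feeds a log-space evaluation so the binomial prefactors do not
    # overflow for large K.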
def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: return -( diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 8dc2180e..7d24e73b 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -197,7 +197,8 @@ def test_spookynet_bernstein_polynomial_equivalence(): num_basis_functions = 3 ref_exp_bernstein_polynomials = RefExponentialBernsteinPolynomials(num_basis_functions, exp_weighting=True) - mf_exp_bernstein_polynomials = MfExponentialBernSteinPolynomials(num_basis_functions, ini_alpha=1.0) + mf_exp_bernstein_polynomials = MfExponentialBernSteinPolynomials(num_basis_functions, ini_alpha=1.0) # TODO: put + # the right number N = 5 r_angstrom = torch.rand((N, 1)) From f52220840e46ac085c9169fc7a5b694fe8cdfe8f Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 15:52:42 -0700 Subject: [PATCH 55/78] Fix schnet test --- modelforge/tests/test_schnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index 4267c533..1fb3773c 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -57,7 +57,7 @@ def test_init(): def test_compare_representation(): # compare schnetpack RadialSymmetryFunction with modelforge RadialSymmetryFunction - from modelforge.potential.utils import SchnetRadialSymmetryFunction + from modelforge.potential.utils import SchnetRadialBasisFunction from openff.units import unit # Initialize the RBFs @@ -65,7 +65,7 @@ def test_compare_representation(): cutoff = unit.Quantity(5.2, unit.angstrom) start = unit.Quantity(0.8, unit.angstrom) - rbf_module = SchnetRadialSymmetryFunction( + rbf_module = SchnetRadialBasisFunction( number_of_radial_basis_functions=number_of_gaussians, max_distance=cutoff, min_distance=start, @@ -238,11 +238,11 @@ def test_compare_forward(): ) calculated_phi_ij = ( schnet.core_module.schnet_representation_module.radial_symmetry_function_module( - d_ij.unsqueeze(1) / 10 + d_ij / 10 ) ) # NOTE: converting to nm - assert torch.allclose(reference_phi_ij, calculated_phi_ij, atol=1e-3) + assert torch.allclose(reference_phi_ij.squeeze(1), calculated_phi_ij, atol=1e-3) # ---------------------------------------- # # test cutoff # ---------------------------------------- # From eda75b99026b5f3e05043ff7cf5a83150ab5a22e Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 15:54:35 -0700 Subject: [PATCH 56/78] Fix test physnet compare representation --- modelforge/tests/test_physnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelforge/tests/test_physnet.py b/modelforge/tests/test_physnet.py index 3d5638a2..5e860c06 100644 --- a/modelforge/tests/test_physnet.py +++ b/modelforge/tests/test_physnet.py @@ -67,5 +67,5 @@ def test_compare_representation(): reference_rbf = provide_reference_for_test_physnet_test_rbf() D = np.array([[1.0394776], [3.375541]], dtype=np.float32) - calculated_rbf = rbf(torch.tensor(D / 10).squeeze()) + calculated_rbf = rbf(torch.tensor(D / 10)) assert np.allclose(np.flip(reference_rbf.squeeze(), axis=1), calculated_rbf.numpy()) From 3e0e4d7c275b205d1809c7508dad3c8dc8111d17 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 16:23:22 -0700 Subject: [PATCH 57/78] Update SpookyNet for postprocessing --- modelforge/potential/spookynet.py | 194 +++++++++--------- .../data/potential_defaults/spookynet.toml | 18 ++ modelforge/tests/test_spookynet.py | 15 +- 3 files changed, 132 
insertions(+), 95 deletions(-) create mode 100644 modelforge/tests/data/potential_defaults/spookynet.toml diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 6efbc68e..b70aaeca 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -50,7 +50,7 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors. f_ij : Optional[torch.Tensor] A tensor representing the radial symmetry function expansion of distances between atom pairs, capturing the - local chemical environment. Shape: [num_pairs, num_features], where `num_features` is the dimensionality of + local chemical environment. Shape: [num_pairs, number_of_atom_features], where `number_of_atom_features` is the dimensionality of the radial symmetry function expansion. This field will be populated after initialization. f_cutoff : Optional[torch.Tensor] A tensor representing the cosine cutoff function applied to the radial symmetry function expansion, ensuring @@ -94,7 +94,7 @@ def __init__( number_of_atom_features: int = 64, number_of_radial_basis_functions: int = 20, number_of_interaction_modules: int = 3, - number_of_filters: int = 64, + number_of_residual_blocks: int = 7, cutoff: unit.Quantity = 5.0 * unit.angstrom, ) -> None: """ @@ -114,9 +114,8 @@ def __init__( from .utils import Dense, ShiftedSoftplus log.debug("Initializing SpookyNet model.") - super().__init__(cutoff) + super().__init__() self.number_of_atom_features = number_of_atom_features - self.number_of_filters = number_of_filters or self.number_of_atom_features self.number_of_radial_basis_functions = number_of_radial_basis_functions # embedding @@ -127,19 +126,24 @@ def __init__( # initialize representation block self.spookynet_representation_block = SpookyNetRepresentation(cutoff, number_of_radial_basis_functions) - # initialize the energy readout - from .processing import FromAtomToMoleculeReduction - - self.readout_module = FromAtomToMoleculeReduction() - # Intialize interaction blocks self.interaction_modules = nn.ModuleList( [ SpookyNetInteractionModule( - self.number_of_atom_features, - self.number_of_filters, - number_of_radial_basis_functions, - ) + number_of_atom_features=number_of_atom_features, + number_of_radial_basis_functions=number_of_radial_basis_functions, + num_residual_pre=number_of_residual_blocks, + num_residual_local_x=number_of_residual_blocks, + num_residual_local_s=number_of_residual_blocks, + num_residual_local_p=number_of_residual_blocks, + num_residual_local_d=number_of_residual_blocks, + num_residual_local=number_of_residual_blocks, + num_residual_nonlocal_q=number_of_residual_blocks, + num_residual_nonlocal_k=number_of_residual_blocks, + num_residual_nonlocal_v=number_of_residual_blocks, + num_residual_post=number_of_residual_blocks, + num_residual_output=number_of_residual_blocks, + ) for _ in range(number_of_interaction_modules) ] ) @@ -178,7 +182,7 @@ def _model_specific_input_preparation( return nnp_input - def _forward(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: + def compute_properties(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: """ Calculate the energy for a given input batch. 
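For orientation, the stack assembled above is consumed in compute_properties by threading the embedded atomic features through each interaction module and summing the per-module outputs (the `f += y` accumulation shown earlier). A toy sketch of that pattern, with StubInteraction standing in for the real SpookyNetInteractionModule:

    import torch
    import torch.nn as nn

    class StubInteraction(nn.Module):
        # placeholder: the real module mixes local s/p/d and nonlocal attention terms
        def __init__(self, n_features: int):
            super().__init__()
            self.linear = nn.Linear(n_features, n_features)

        def forward(self, x):
            x = self.linear(x)
            return x, x  # (updated features, module output)

    modules = nn.ModuleList([StubInteraction(8) for _ in range(3)])
    x = torch.randn(4, 8)    # embedded features, shape [num_atoms, n_features]
    f = torch.zeros_like(x)  # accumulator that feeds the energy readout
    for module in modules:
        x, y = module(x)
        f = f + y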
@@ -224,12 +228,14 @@ def _forward(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: class SpookyNet(BaseNetwork): def __init__( self, - max_Z: int = 101, - number_of_atom_features: int = 32, - number_of_radial_basis_functions: int = 20, - number_of_interaction_modules: int = 3, - cutoff: unit.Quantity = 5 * unit.angstrom, - number_of_filters: int = 32, + max_Z: int, + number_of_atom_features: int, + number_of_radial_basis_functions: int, + number_of_interaction_modules: int, + number_of_residual_blocks: int, + cutoff: unit.Quantity, + postprocessing_parameter: Dict[str, Dict[str, bool]], + dataset_statistic: Optional[Dict[str, float]] = None, ) -> None: """ Initialize the SpookyNet network. @@ -239,26 +245,31 @@ def __init__( Parameters ---------- - max_Z : int, default=100 + max_Z : int Maximum atomic number to be embedded. - number_of_atom_features : int, default=64 + number_of_atom_features : int Dimension of the embedding vectors for atomic numbers. - number_of_radial_basis_functions:int, default=16 - number_of_interaction_modules : int, default=2 - cutoff : openff.units.unit.Quantity, default=5*unit.angstrom + number_of_radial_basis_functions :int + number_of_interaction_modules : int + cutoff : openff.units.unit.Quantity The cutoff distance for interactions. """ - super().__init__() + super().__init__( + dataset_statistic=dataset_statistic, + postprocessing_parameter=postprocessing_parameter, + ) + from modelforge.utils.units import _convert + self.core_module = SpookyNetCore( max_Z=max_Z, number_of_atom_features=number_of_atom_features, number_of_radial_basis_functions=number_of_radial_basis_functions, number_of_interaction_modules=number_of_interaction_modules, - number_of_filters=number_of_filters, + number_of_residual_blocks=number_of_residual_blocks, ) self.only_unique_pairs = False # NOTE: for pairlist self.input_preparation = InputPreparation( - cutoff=cutoff, only_unique_pairs=self.only_unique_pairs + cutoff=_convert(cutoff), only_unique_pairs=self.only_unique_pairs ) def _config_prior(self): @@ -272,7 +283,6 @@ def _config_prior(self): "number_of_interaction_modules": tune.randint(1, 5), "cutoff": tune.uniform(5, 10), "number_of_radial_basis_functions": tune.randint(8, 32), - "number_of_filters": tune.randint(32, 128), "shared_interactions": tune.choice([True, False]), } prior.update(shared_config_prior()) @@ -361,7 +371,7 @@ class Swish(nn.Module): For beta -> inf: f(x) -> max(0, alpha*x) Arguments: - num_features (int): + number_of_atom_features (int): Dimensions of feature space. initial_alpha (float): Initial "scale" alpha of the "linear component". @@ -374,14 +384,14 @@ class Swish(nn.Module): """ def __init__( - self, num_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.702 + self, number_of_atom_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.702 ) -> None: """ Initializes the Swish class. 
""" super(Swish, self).__init__() self.initial_alpha = initial_alpha self.initial_beta = initial_beta - self.register_parameter("alpha", nn.Parameter(torch.Tensor(num_features))) - self.register_parameter("beta", nn.Parameter(torch.Tensor(num_features))) + self.register_parameter("alpha", nn.Parameter(torch.Tensor(number_of_atom_features))) + self.register_parameter("beta", nn.Parameter(torch.Tensor(number_of_atom_features))) self.reset_parameters() def reset_parameters(self) -> None: @@ -392,14 +402,14 @@ def reset_parameters(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ Evaluate activation function given the input features x. - num_features: Dimensions of feature space. + number_of_atom_features: Dimensions of feature space. Arguments: - x (FloatTensor [:, num_features]): + x (FloatTensor [:, number_of_atom_features]): Input features. Returns: - y (FloatTensor [:, num_features]): + y (FloatTensor [:, number_of_atom_features]): Activated features. """ return self.alpha * F.silu(self.beta * x) @@ -411,22 +421,22 @@ class SpookyNetResidual(nn.Module): mappings in deep residual networks.". Arguments: - num_features (int): + number_of_atom_features (int): Dimensions of feature space. """ def __init__( self, - num_features: int, + number_of_atom_features: int, bias: bool = True, ) -> None: """ Initializes the Residual class. """ super(SpookyNetResidual, self).__init__() # initialize attributes - self.activation1 = Swish(num_features) - self.linear1 = nn.Linear(num_features, num_features, bias=bias) - self.activation2 = Swish(num_features) - self.linear2 = nn.Linear(num_features, num_features, bias=bias) + self.activation1 = Swish(number_of_atom_features) + self.linear1 = nn.Linear(number_of_atom_features, number_of_atom_features, bias=bias) + self.activation2 = Swish(number_of_atom_features) + self.linear2 = nn.Linear(number_of_atom_features, number_of_atom_features, bias=bias) self.reset_parameters(bias) def reset_parameters(self, bias: bool = True) -> None: @@ -441,14 +451,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ Apply residual block to input atomic features. N: Number of atoms. - num_features: Dimensions of feature space. + number_of_atom_features: Dimensions of feature space. Arguments: - x (FloatTensor [N, num_features]): + x (FloatTensor [N, number_of_atom_features]): Input feature representations of atoms. Returns: - y (FloatTensor [N, num_features]): + y (FloatTensor [N, number_of_atom_features]): Output feature representations of atoms. """ y = self.activation1(x) @@ -463,7 +473,7 @@ class SpookyNetResidualStack(nn.Module): Stack of num_blocks pre-activation residual blocks evaluated in sequence. Arguments: - num_features (int): + number_of_atom_features (int): Dimensions of feature space. number_of_residual_blocks (int): Number of residual blocks to be stacked in sequence. @@ -471,7 +481,7 @@ class SpookyNetResidualStack(nn.Module): def __init__( self, - num_features: int, + number_of_atom_features: int, number_of_residual_blocks: int, bias: bool = True, ) -> None: @@ -479,7 +489,7 @@ def __init__( super(SpookyNetResidualStack, self).__init__() self.stack = nn.ModuleList( [ - SpookyNetResidual(num_features, bias) + SpookyNetResidual(number_of_atom_features, bias) for _ in range(number_of_residual_blocks) ] ) @@ -488,14 +498,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ Applies all residual blocks to input features in sequence. N: Number of inputs. - num_features: Dimensions of feature space. 
+ number_of_atom_features: Dimensions of feature space. Arguments: - x (FloatTensor [N, num_features]): + x (FloatTensor [N, number_of_atom_features]): Input feature representations. Returns: - y (FloatTensor [N, num_features]): + y (FloatTensor [N, number_of_atom_features]): Output feature representations. """ for residual in self.stack: @@ -506,16 +516,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SpookyNetResidualMLP(nn.Module): def __init__( self, - num_features: int, + number_of_atom_features: int, number_of_residual_blocks: int, bias: bool = True, ) -> None: super(SpookyNetResidualMLP, self).__init__() self.residual = SpookyNetResidualStack( - num_features, number_of_residual_blocks, bias=bias + number_of_atom_features, number_of_residual_blocks, bias=bias ) - self.activation = Swish(num_features) - self.linear = nn.Linear(num_features, num_features, bias=bias) + self.activation = Swish(number_of_atom_features) + self.linear = nn.Linear(number_of_atom_features, number_of_atom_features, bias=bias) self.reset_parameters(bias) def reset_parameters(self, bias: bool = True) -> None: @@ -533,9 +543,9 @@ class SpookyNetLocalInteraction(nn.Module): neighboring atoms (message-passing). Arguments: - num_features (int): + number_of_atom_features (int): Dimensions of feature space. - num_basis_functions (int): + number_of_radial_basis_functions (int): Number of radial basis functions. num_residual_x (int): TODO @@ -551,8 +561,8 @@ class SpookyNetLocalInteraction(nn.Module): def __init__( self, - num_features: int, - num_basis_functions: int, + number_of_atom_features: int, + number_of_radial_basis_functions: int, num_residual_x: int, num_residual_s: int, num_residual_p: int, @@ -561,17 +571,17 @@ def __init__( ) -> None: """ Initializes the LocalInteraction class. """ super(SpookyNetLocalInteraction, self).__init__() - self.radial_s = nn.Linear(num_basis_functions, num_features, bias=False) - self.radial_p = nn.Linear(num_basis_functions, num_features, bias=False) - self.radial_d = nn.Linear(num_basis_functions, num_features, bias=False) - self.resblock_x = SpookyNetResidualMLP(num_features, num_residual_x) - self.resblock_s = SpookyNetResidualMLP(num_features, num_residual_s) - self.resblock_p = SpookyNetResidualMLP(num_features, num_residual_p) - self.resblock_d = SpookyNetResidualMLP(num_features, num_residual_d) - self.projection_p = nn.Linear(num_features, 2 * num_features, bias=False) - self.projection_d = nn.Linear(num_features, 2 * num_features, bias=False) + self.radial_s = nn.Linear(number_of_radial_basis_functions, number_of_atom_features, bias=False) + self.radial_p = nn.Linear(number_of_radial_basis_functions, number_of_atom_features, bias=False) + self.radial_d = nn.Linear(number_of_radial_basis_functions, number_of_atom_features, bias=False) + self.resblock_x = SpookyNetResidualMLP(number_of_atom_features, num_residual_x) + self.resblock_s = SpookyNetResidualMLP(number_of_atom_features, num_residual_s) + self.resblock_p = SpookyNetResidualMLP(number_of_atom_features, num_residual_p) + self.resblock_d = SpookyNetResidualMLP(number_of_atom_features, num_residual_d) + self.projection_p = nn.Linear(number_of_atom_features, 2 * number_of_atom_features, bias=False) + self.projection_d = nn.Linear(number_of_atom_features, 2 * number_of_atom_features, bias=False) self.resblock = SpookyNetResidualMLP( - num_features, num_residual + number_of_atom_features, num_residual ) self.reset_parameters() @@ -597,9 +607,9 @@ def forward( N: Number of atoms. P: Number of atom pairs. 
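# The residual blocks above follow the pre-activation pattern of He et al.
# ("Identity mappings in deep residual networks"): activation -> linear ->
# activation -> linear, added onto an untouched identity skip. Zero-initializing
# the second linear layer makes every block an exact identity map at the start
# of training, which keeps deep stacks trainable. A condensed sketch of the
# same pattern (illustrative; a plain SiLU stands in for the learnable Swish):
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreActResidual(nn.Module):
    def __init__(self, n: int):
        super().__init__()
        self.linear1 = nn.Linear(n, n)
        self.linear2 = nn.Linear(n, n)
        nn.init.orthogonal_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.zeros_(self.linear2.weight)  # residual branch starts at zero
        nn.init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.linear2(F.silu(self.linear1(F.silu(x))))

x = torch.randn(5, 16)
assert torch.allclose(PreActResidual(16)(x), x)  # identity at initialization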
- x (FloatTensor [N, num_features]): + x (FloatTensor [N, number_of_atom_features]): Atomic feature vectors. - rbf (FloatTensor [N, num_basis_functions]): + rbf (FloatTensor [N, number_of_radial_basis_functions]): Values of the radial basis functions for the pairwise distances. dir_ij (TODO): TODO @@ -737,7 +747,7 @@ class SpookyNetNonlocalInteraction(nn.Module): atoms. Arguments: - num_features (int): + number_of_atom_features (int): Dimensions of feature space. num_residual_q (int): Number of residual blocks for queries. @@ -749,7 +759,7 @@ class SpookyNetNonlocalInteraction(nn.Module): def __init__( self, - num_features: int, + number_of_atom_features: int, num_residual_q: int, num_residual_k: int, num_residual_v: int, @@ -757,15 +767,15 @@ def __init__( """ Initializes the NonlocalInteraction class. """ super(SpookyNetNonlocalInteraction, self).__init__() self.resblock_q = SpookyNetResidualMLP( - num_features, num_residual_q + number_of_atom_features, num_residual_q ) self.resblock_k = SpookyNetResidualMLP( - num_features, num_residual_k + number_of_atom_features, num_residual_k ) self.resblock_v = SpookyNetResidualMLP( - num_features, num_residual_v + number_of_atom_features, num_residual_v ) - self.attention = SpookyNetAttention(dim_qk=num_features, num_random_features=num_features) + self.attention = SpookyNetAttention(dim_qk=number_of_atom_features, num_random_features=number_of_atom_features) self.reset_parameters() def reset_parameters(self) -> None: @@ -780,7 +790,7 @@ def forward( Evaluate interaction block. N: Number of atoms. - x (FloatTensor [N, num_features]): + x (FloatTensor [N, number_of_atom_features]): Atomic feature vectors. """ q = self.resblock_q(x_tilde) # queries @@ -794,9 +804,9 @@ class SpookyNetInteractionModule(nn.Module): InteractionModule of SpookyNet, which computes a single iteration. Arguments: - num_features (int): + number_of_atom_features (int): Dimensions of feature space. - num_basis_functions (int): + number_of_radial_basis_functions (int): Number of radial basis functions. 
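# SpookyNetAttention above sidesteps the O(N^2) cost of exact softmax attention
# with the random-feature factorization used by Performer-style models: queries
# and keys are mapped through a feature function phi so that attention becomes
# phi(Q) @ (phi(K)^T V), two products that are linear in the number of atoms N.
# A self-contained sketch of that factorization (the free functions here are
# illustrative, not the module's API):
import math
import torch

def phi(X: torch.Tensor, omega: torch.Tensor, is_query: bool, eps: float = 1e-4):
    # normalize X and project into the random feature space spanned by omega;
    # the max is subtracted to prevent numerical overflow in exp()
    d, m = X.shape[-1], omega.shape[-1]
    U = (X / d ** 0.25) @ omega
    h = (X ** 2).sum(-1, keepdim=True) / (2 * d ** 0.5)
    maximum = U.max(-1, keepdim=True).values if is_query else U.max()
    return (torch.exp(U - h - maximum) + eps) / math.sqrt(m)

def linear_attention(Q, K, V, omega, eps: float = 1e-8):
    Qp, Kp = phi(Q, omega, is_query=True), phi(K, omega, is_query=False)
    norm = Qp @ Kp.sum(0, keepdim=True).T + eps  # [N, 1] softmax denominator
    return (Qp @ (Kp.T @ V)) / norm              # O(N*m*d) instead of O(N^2*d)

N, d = 6, 8
out = linear_attention(torch.randn(N, d), torch.randn(N, d), torch.randn(N, d),
                       torch.randn(d, d))  # the real module draws orthogonal blocks
assert out.shape == (N, d)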
num_residual_pre (int): Number of residual blocks applied to atomic features before @@ -827,8 +837,8 @@ class SpookyNetInteractionModule(nn.Module): def __init__( self, - num_features: int, - num_basis_functions: int, + number_of_atom_features: int, + number_of_radial_basis_functions: int, num_residual_pre: int, num_residual_local_x: int, num_residual_local_s: int, @@ -845,8 +855,8 @@ def __init__( super(SpookyNetInteractionModule, self).__init__() # initialize modules self.local_interaction = SpookyNetLocalInteraction( - num_features=num_features, - num_basis_functions=num_basis_functions, + number_of_atom_features=number_of_atom_features, + number_of_radial_basis_functions=number_of_radial_basis_functions, num_residual_x=num_residual_local_x, num_residual_s=num_residual_local_s, num_residual_p=num_residual_local_p, @@ -854,15 +864,15 @@ def __init__( num_residual=num_residual_local, ) self.nonlocal_interaction = SpookyNetNonlocalInteraction( - num_features=num_features, + number_of_atom_features=number_of_atom_features, num_residual_q=num_residual_nonlocal_q, num_residual_k=num_residual_nonlocal_k, num_residual_v=num_residual_nonlocal_v, ) - self.residual_pre = SpookyNetResidualStack(num_features, num_residual_pre) - self.residual_post = SpookyNetResidualStack(num_features, num_residual_post) - self.resblock = SpookyNetResidualMLP(num_features, num_residual_output) + self.residual_pre = SpookyNetResidualStack(number_of_atom_features, num_residual_pre) + self.residual_post = SpookyNetResidualStack(number_of_atom_features, num_residual_post) + self.resblock = SpookyNetResidualMLP(number_of_atom_features, num_residual_output) self.reset_parameters() def reset_parameters(self) -> None: @@ -885,9 +895,9 @@ def forward( B: Batch size (number of different molecules). Arguments: - x (FloatTensor [N, num_features]): + x (FloatTensor [N, number_of_atom_features]): Latent atomic feature vectors. - rbf (FloatTensor [P, num_basis_functions]): + rbf (FloatTensor [P, number_of_radial_basis_functions]): Values of the radial basis functions for the pairwise distances. dir_ij (FloatTensor [P, 3]): Unit vectors pointing from atom i to atom j for all atomic pairs. @@ -899,9 +909,9 @@ def forward( idx_j (LongTensor [P]): Same as idx_i, but for atom j. Returns: - x (FloatTensor [N, num_features]): + x (FloatTensor [N, number_of_atom_features]): Updated latent atomic feature vectors. - y (FloatTensor [N, num_features]): + y (FloatTensor [N, number_of_atom_features]): Contribution to output atomic features (environment descriptors). 
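# Condensed view of what one interaction module computes (illustrative
# pseudocode; the attribute names match the __init__ above, and the real
# forward pass appears later in this series): features pass through a
# pre-residual stack, a local message-passing branch and a nonlocal attention
# branch are evaluated on the same x_tilde, and a post-residual stack merges
# their sum; a final residual MLP produces the per-module output y.
def interaction_module_flow(module, x, filters, dir_ij, d_orbital_ij, idx_i, idx_j):
    x_tilde = module.residual_pre(x)
    l = module.local_interaction(x_tilde, filters, dir_ij, d_orbital_ij,
                                 idx_i, idx_j)        # messages from neighbors
    n = module.nonlocal_interaction(x_tilde)          # attention over all atoms
    x_updated = module.residual_post(x_tilde + l + n)
    return x_updated, module.resblock(x_updated)      # (next x, output y)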
""" diff --git a/modelforge/tests/data/potential_defaults/spookynet.toml b/modelforge/tests/data/potential_defaults/spookynet.toml new file mode 100644 index 00000000..a9071aa4 --- /dev/null +++ b/modelforge/tests/data/potential_defaults/spookynet.toml @@ -0,0 +1,18 @@ +[potential] +model_name = "SpookyNet" + +[potential.core_parameter] +max_Z = 101 +number_of_atom_features = 32 +number_of_radial_basis_functions = 20 +cutoff = "5.0 angstrom" +number_of_interaction_modules = 3 +number_of_residual_blocks = 7 + +[potential.postprocessing_parameter] +[potential.postprocessing_parameter.per_atom_energy] +normalize = true +from_atom_to_molecule_reduction = true +keep_per_atom_property = true +[potential.postprocessing_parameter.general_postprocessing_operation] +calculate_molecular_self_energy = true diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 7d24e73b..f6caaf18 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -5,11 +5,20 @@ import pytest -def test_spookynet_init(): +def test_init(): """Test initialization of the SpookyNet model.""" + from modelforge.potential.spookynet import SpookyNet - spookynet = SpookyNet() - assert spookynet is not None, "SpookyNet model should be initialized." + from modelforge.tests.test_models import load_configs + + # load default parameters + config = load_configs(f"spookynet", "qm9") + # initialize model + spookynet = SpookyNet( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ) + assert spookynet is not None, "Schnet model should be initialized." from openff.units import unit From c4d61fe61ee22542c3e89a7afb7583d8d69d5654 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 17:23:17 -0700 Subject: [PATCH 58/78] Working through forward test --- modelforge/potential/spookynet.py | 20 ++++++------- modelforge/tests/test_schnet.py | 2 +- modelforge/tests/test_spookynet.py | 47 +++++++++++++++++++++++++----- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index b70aaeca..b4c56837 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -124,7 +124,7 @@ def __init__( self.embedding_module = Embedding(max_Z, number_of_atom_features) # initialize representation block - self.spookynet_representation_block = SpookyNetRepresentation(cutoff, number_of_radial_basis_functions) + self.spookynet_representation_module = SpookyNetRepresentation(cutoff, number_of_radial_basis_functions) # Intialize interaction blocks self.interaction_modules = nn.ModuleList( @@ -197,7 +197,7 @@ def compute_properties(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torc """ # Compute the representation for each atom (transform to radial basis set, multiply by cutoff) - representation = self.spookynet_representation_module(data.d_ij) + representation = self.spookynet_representation_module(data.d_ij, data.r_ij) x = data.atomic_embedding f = x.new_zeros(x.size()) # initialize output features to zero @@ -206,8 +206,7 @@ def compute_properties(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torc x, y = interaction( x, data.pair_indices, - representation["f_ij"], - representation["f_cutoff"], + representation["filters"], representation["dir_ij"], representation["d_orbital_ij"] ) @@ -357,7 +356,7 @@ def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Ten ) f_ij = 
self.radial_symmetry_function_module(d_ij) f_ij_cutoff = self.cutoff_module(d_ij) - filters = f_ij * f_ij_cutoff + filters = f_ij * f_ij_cutoff # TODO: replace with einsum return {"filters": filters, "dir_ij": dir_ij, "d_orbital_ij": d_orbital_ij} @@ -611,10 +610,10 @@ def forward( Atomic feature vectors. rbf (FloatTensor [N, number_of_radial_basis_functions]): Values of the radial basis functions for the pairwise distances. - dir_ij (TODO): - TODO + dir_ij (TODO:): + TODO: d_orbital_ij (TODO): - TODO + TODO: idx_i (LongTensor [P]): Index of atom i for all atomic pairs ij. Each pair must be specified as both ij and ji. @@ -883,8 +882,7 @@ def forward( self, x: torch.Tensor, pairlist: torch.Tensor, # shape [n_pairs, 2] - f_ij: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] TODO: why the 1? - f_ij_cutoff: torch.Tensor, # shape [n_pairs, 1] + filters: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] TODO: why the 1? dir_ij: torch.Tensor, # shape [n_pairs, 1] d_orbital_ij: torch.Tensor, # shape [n_pairs, 1] ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -918,7 +916,7 @@ def forward( idx_i, idx_j = pairlist[0], pairlist[1] x_tilde = self.residual_pre(x) del x - l = self.local_interaction(x_tilde, f_ij * f_ij_cutoff, dir_ij, d_orbital_ij, idx_i, idx_j) + l = self.local_interaction(x_tilde, filters, dir_ij, d_orbital_ij, idx_i, idx_j) n = self.nonlocal_interaction(x_tilde) x_updated = self.residual_post(x_tilde + l + n) del x_tilde diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index 1fb3773c..a8853e63 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -56,7 +56,7 @@ def test_init(): def test_compare_representation(): - # compare schnetpack RadialSymmetryFunction with modelforge RadialSymmetryFunction + # compare schnetpack RadialBasisFunction with modelforge RadialBasisFunction from modelforge.potential.utils import SchnetRadialBasisFunction from openff.units import unit diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index f6caaf18..957d7cfa 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -1,5 +1,8 @@ from modelforge.potential.spookynet import SpookyNet from spookynet import SpookyNet as RefSpookyNet +from modelforge.tests.precalculated_values import ( + setup_single_methane_input, +) import torch import pytest @@ -24,14 +27,42 @@ def test_init(): from openff.units import unit -@pytest.mark.parametrize( - "model_parameter", - ( - [64, 50, 20, unit.Quantity(5.0, unit.angstrom), 2], - [32, 60, 10, unit.Quantity(7.0, unit.angstrom), 1], - [128, 120, 64, unit.Quantity(5.0, unit.angstrom), 3], - ), -) +def test_forward(): + # ---------------------------------------- # + # test the implementation of the representation part of the PaiNN model + # ---------------------------------------- # + from modelforge.potential.spookynet import SpookyNet + + from modelforge.tests.test_models import load_configs + + # load default parameters + config = load_configs(f"spookynet", "qm9") + + # override default parameters + config["potential"]["core_parameter"]["number_of_atom_features"] = 12 + config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5 + + torch.manual_seed(1234) + + # initialize model + spookynet = SpookyNet( + **config["potential"]["core_parameter"], + postprocessing_parameter=config["potential"]["postprocessing_parameter"], + ).double() + + input = setup_single_methane_input() + model_input = 
input["modelforge_methane_input"] + + + spookynet.input_preparation._input_checks(model_input) + + pairlist_output = spookynet.input_preparation.prepare_inputs(model_input) + prepared_input = spookynet.core_module._model_specific_input_preparation( + model_input, pairlist_output + ) + calculated_results = spookynet.core_module.forward(model_input, pairlist_output) + + def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter): """ Test the forward pass of the SpookyNet model. From f7b132df8c1643b9df729b6ca54f1f97d5f9bc2c Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 17:44:06 -0700 Subject: [PATCH 59/78] Fix bug that resulted in incorrect dir_ij --- modelforge/potential/spookynet.py | 25 ++++++++++++++++--------- modelforge/potential/utils.py | 3 --- modelforge/tests/test_spookynet.py | 1 + 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index b4c56837..217d88fe 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -204,11 +204,11 @@ def compute_properties(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torc # Iterate over interaction blocks to update features for interaction in self.interaction_modules: x, y = interaction( - x, - data.pair_indices, - representation["filters"], - representation["dir_ij"], - representation["d_orbital_ij"] + x=x, + pairlist=data.pair_indices, + filters=representation["filters"], + dir_ij=representation["dir_ij"], + d_orbital_ij=representation["d_orbital_ij"], ) f += y # accumulate module output to features @@ -328,9 +328,9 @@ def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Ten Parameters ---------- d_ij : torch.Tensor - pairwise distances between atoms, shape (n_pairs). + pairwise distances between atoms, shape [num_pairs, 1]. r_ij : torch.Tensor - pairwise displacements between atoms, shape (n_pairs, 3). + pairwise displacements between atoms, shape [num_pairs, 3]. Returns ------- @@ -342,7 +342,7 @@ def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Ten sqrt3 = math.sqrt(3) sqrt3half = 0.5 * sqrt3 # short-range distances - dir_ij = r_ij / d_ij.unsqueeze(-1) + dir_ij = r_ij / d_ij d_orbital_ij = torch.stack( [ sqrt3 * dir_ij[:, 0] * dir_ij[:, 1], # xy @@ -916,7 +916,14 @@ def forward( idx_i, idx_j = pairlist[0], pairlist[1] x_tilde = self.residual_pre(x) del x - l = self.local_interaction(x_tilde, filters, dir_ij, d_orbital_ij, idx_i, idx_j) + l = self.local_interaction( + x_tilde=x_tilde, + f_ij_after_cutoff=filters, + dir_ij=dir_ij, + d_orbital_ij=d_orbital_ij, + idx_i=idx_i, + idx_j=idx_j, + ) n = self.nonlocal_interaction(x_tilde) x_updated = self.residual_post(x_tilde + l + n) del x_tilde diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index f33195b9..188593fd 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -544,8 +544,6 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: rbf (FloatTensor [N, num_basis_functions]): Values of the radial basis functions for the distances r. 
""" - print(f"{nondimensionalized_distances.shape=}") - print(f"{self.number_of_radial_basis_functions=}") assert nondimensionalized_distances.ndim == 2 assert ( nondimensionalized_distances.shape[1] @@ -556,7 +554,6 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + (self.n + 1) * nondimensionalized_distances + self.v * torch.log(-torch.expm1(nondimensionalized_distances)) ) - print(f"{self.logc.shape=}") return torch.exp(x) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 957d7cfa..6f42ef72 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -57,6 +57,7 @@ def test_forward(): spookynet.input_preparation._input_checks(model_input) pairlist_output = spookynet.input_preparation.prepare_inputs(model_input) + print(f"{pairlist_output.d_ij.shape=}") prepared_input = spookynet.core_module._model_specific_input_preparation( model_input, pairlist_output ) From 9a52c01b667bb920ff4d6c73432346e7b2451cc9 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 17:53:09 -0700 Subject: [PATCH 60/78] Replace * with einsum --- modelforge/potential/spookynet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 217d88fe..59d2e90e 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -622,8 +622,9 @@ def forward( """ # interaction functions gs = self.radial_s(f_ij_after_cutoff) - gp = self.radial_p(f_ij_after_cutoff).unsqueeze(-2) * dir_ij.unsqueeze(-1) # TODO: replace with einsum - gd = self.radial_d(f_ij_after_cutoff).unsqueeze(-2) * d_orbital_ij.unsqueeze(-1) # TODO: replace with einsum + # p: num_pairs, f: number_of_atomic_features, r: number_of_radial_basis_functions + gp = torch.einsum("pf,pr->prf", self.radial_p(f_ij_after_cutoff), dir_ij) + gd = torch.einsum("pf,pr->prf", self.radial_d(f_ij_after_cutoff), d_orbital_ij) # atom featurizations xx = self.resblock_x(x_tilde) xs = self.resblock_s(x_tilde) From 2ec3f3fcf107b0b7b781eecd4dfc27716f77b618 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 18:49:10 -0700 Subject: [PATCH 61/78] Replace more operations with einsum --- modelforge/potential/spookynet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 59d2e90e..b804c02f 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -643,7 +643,8 @@ def forward( # project tensorial features to scalars pa, pb = torch.split(self.projection_p(p), p.shape[-1], dim=-1) da, db = torch.split(self.projection_d(d), d.shape[-1], dim=-1) - return self.resblock(s + (pa * pb).sum(-2) + (da * db).sum(-2)) # TODO: replace with einsum + # r: number_of_radial_basis_functions, x: 3 (geometry axis), f: number_of_atom_features + return self.resblock(s + torch.einsum("rxf,rxf->rf", pa, pb) + torch.einsum("rxf,rxf->rf", da, db)) class SpookyNetAttention(nn.Module): From 5a32fe4abfab0746282c2f267d7cab07bf6a5043 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 19:09:54 -0700 Subject: [PATCH 62/78] Replace one more operation with torch.einsum. 
Fix axis label for n --- modelforge/potential/spookynet.py | 11 ++++++----- modelforge/tests/test_spookynet.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index b804c02f..e511a5bc 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -637,14 +637,15 @@ def forward( # sum over neighbors pp = x_tilde.new_zeros(x_tilde.shape[0], dir_ij.shape[-1], x_tilde.shape[-1]) dd = x_tilde.new_zeros(x_tilde.shape[0], d_orbital_ij.shape[-1], x_tilde.shape[-1]) - s = xx.index_add(0, idx_i, gs * xs) # L=0 # TODO: replace with einsum - p = pp.index_add_(0, idx_i, gp * xp.unsqueeze(-2)) # L=1 # TODO: replace with einsum - d = dd.index_add_(0, idx_i, gd * xd.unsqueeze(-2)) # L=2 # TODO: replace with einsum + s = xx.index_add(0, idx_i, torch.einsum("pf,pf->pf", gs, xs)) # L=0 + # p: num_pairs, x: 3 (geometry axis), f: number_of_atom_features + p = pp.index_add_(0, idx_i, torch.einsum("pxf,pf->pxf", gp, xp)) # L=1 + d = dd.index_add_(0, idx_i, torch.einsum("pxf,pf->pxf", gd, xd)) # L=2 # project tensorial features to scalars pa, pb = torch.split(self.projection_p(p), p.shape[-1], dim=-1) da, db = torch.split(self.projection_d(d), d.shape[-1], dim=-1) - # r: number_of_radial_basis_functions, x: 3 (geometry axis), f: number_of_atom_features - return self.resblock(s + torch.einsum("rxf,rxf->rf", pa, pb) + torch.einsum("rxf,rxf->rf", da, db)) + # n: number_of_atoms_in_system, x: 3 (geometry axis), f: number_of_atom_features + return self.resblock(s + torch.einsum("nxf,nxf->nf", pa, pb) + torch.einsum("nxf,nxf->nf", da, db)) class SpookyNetAttention(nn.Module): diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 6f42ef72..205f7306 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -40,7 +40,7 @@ def test_forward(): # override default parameters config["potential"]["core_parameter"]["number_of_atom_features"] = 12 - config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5 + config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 7 torch.manual_seed(1234) From 6ce39e44127639be834a51113213916b0b98d0da Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 19:24:58 -0700 Subject: [PATCH 63/78] Add explicit broadcast. Reformat code. --- modelforge/potential/spookynet.py | 310 +++++++++++++++++------------- 1 file changed, 174 insertions(+), 136 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index e511a5bc..47295d9c 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -89,13 +89,13 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): class SpookyNetCore(CoreNetwork): def __init__( - self, - max_Z: int = 100, - number_of_atom_features: int = 64, - number_of_radial_basis_functions: int = 20, - number_of_interaction_modules: int = 3, - number_of_residual_blocks: int = 7, - cutoff: unit.Quantity = 5.0 * unit.angstrom, + self, + max_Z: int = 100, + number_of_atom_features: int = 64, + number_of_radial_basis_functions: int = 20, + number_of_interaction_modules: int = 3, + number_of_residual_blocks: int = 7, + cutoff: unit.Quantity = 5.0 * unit.angstrom, ) -> None: """ Initialize the SpookyNet class. 
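# The einsum rewrites in the commits above are notation-only changes: each
# spells out the broadcast axes explicitly (p = pair, x = Cartesian/orbital
# axis, f = feature) without changing the computed values. A quick
# equivalence check (illustrative):
import torch

p, f = 10, 4                # pairs, features; 3 = orbital/geometry axis
g = torch.randn(p, 3, f)    # radial filter lifted to p-orbital shape
x = torch.randn(p, f)       # gathered neighbor features

old = g * x.unsqueeze(-2)                # implicit broadcasting
new = torch.einsum("pxf,pf->pxf", g, x)  # explicit axis labels
assert torch.allclose(old, new)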
@@ -124,7 +124,9 @@ def __init__( self.embedding_module = Embedding(max_Z, number_of_atom_features) # initialize representation block - self.spookynet_representation_module = SpookyNetRepresentation(cutoff, number_of_radial_basis_functions) + self.spookynet_representation_module = SpookyNetRepresentation( + cutoff, number_of_radial_basis_functions + ) # Intialize interaction blocks self.interaction_modules = nn.ModuleList( @@ -143,7 +145,7 @@ def __init__( num_residual_nonlocal_v=number_of_residual_blocks, num_residual_post=number_of_residual_blocks, num_residual_output=number_of_residual_blocks, - ) + ) for _ in range(number_of_interaction_modules) ] ) @@ -162,7 +164,7 @@ def __init__( ) def _model_specific_input_preparation( - self, data: "NNPInput", pairlist_output: "PairListOutputs" + self, data: "NNPInput", pairlist_output: "PairListOutputs" ) -> SpookyNetNeuralNetworkData: number_of_atoms = data.atomic_numbers.shape[0] @@ -182,7 +184,9 @@ def _model_specific_input_preparation( return nnp_input - def compute_properties(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torch.Tensor]: + def compute_properties( + self, data: SpookyNetNeuralNetworkData + ) -> Dict[str, torch.Tensor]: """ Calculate the energy for a given input batch. @@ -226,19 +230,19 @@ def compute_properties(self, data: SpookyNetNeuralNetworkData) -> Dict[str, torc class SpookyNet(BaseNetwork): def __init__( - self, - max_Z: int, - number_of_atom_features: int, - number_of_radial_basis_functions: int, - number_of_interaction_modules: int, - number_of_residual_blocks: int, - cutoff: unit.Quantity, - postprocessing_parameter: Dict[str, Dict[str, bool]], - dataset_statistic: Optional[Dict[str, float]] = None, + self, + max_Z: int, + number_of_atom_features: int, + number_of_radial_basis_functions: int, + number_of_interaction_modules: int, + number_of_residual_blocks: int, + cutoff: unit.Quantity, + postprocessing_parameter: Dict[str, Dict[str, bool]], + dataset_statistic: Optional[Dict[str, float]] = None, ) -> None: """ Initialize the SpookyNet network. - + Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). @@ -291,9 +295,9 @@ def _config_prior(self): class SpookyNetRepresentation(nn.Module): def __init__( - self, - cutoff: unit = 5 * unit.angstrom, - number_of_radial_basis_functions: int = 16, + self, + cutoff: unit = 5 * unit.angstrom, + number_of_radial_basis_functions: int = 16, ): """ Representation module for the PhysNet potential, handling the generation of @@ -321,7 +325,9 @@ def __init__( self.cutoff_module = CosineCutoff(cutoff=cutoff) - def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Tensor]: + def forward( + self, d_ij: torch.Tensor, r_ij: torch.Tensor + ) -> dict[str, torch.Tensor]: """ Forward pass of the representation module. 
@@ -356,7 +362,12 @@ def forward(self, d_ij: torch.Tensor, r_ij: torch.Tensor) -> dict[str, torch.Ten ) f_ij = self.radial_symmetry_function_module(d_ij) f_ij_cutoff = self.cutoff_module(d_ij) - filters = f_ij * f_ij_cutoff # TODO: replace with einsum + filters = ( + f_ij.broadcast_to( + len(d_ij), + self.radial_symmetry_function_module.radial_basis_function.number_of_radial_basis_functions, + ) + ) * f_ij_cutoff return {"filters": filters, "dir_ij": dir_ij, "d_orbital_ij": d_orbital_ij} @@ -383,18 +394,25 @@ class Swish(nn.Module): """ def __init__( - self, number_of_atom_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.702 + self, + number_of_atom_features: int, + initial_alpha: float = 1.0, + initial_beta: float = 1.702, ) -> None: - """ Initializes the Swish class. """ + """Initializes the Swish class.""" super(Swish, self).__init__() self.initial_alpha = initial_alpha self.initial_beta = initial_beta - self.register_parameter("alpha", nn.Parameter(torch.Tensor(number_of_atom_features))) - self.register_parameter("beta", nn.Parameter(torch.Tensor(number_of_atom_features))) + self.register_parameter( + "alpha", nn.Parameter(torch.Tensor(number_of_atom_features)) + ) + self.register_parameter( + "beta", nn.Parameter(torch.Tensor(number_of_atom_features)) + ) self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters alpha and beta. """ + """Initialize parameters alpha and beta.""" nn.init.constant_(self.alpha, self.initial_alpha) nn.init.constant_(self.beta, self.initial_beta) @@ -425,21 +443,25 @@ class SpookyNetResidual(nn.Module): """ def __init__( - self, - number_of_atom_features: int, - bias: bool = True, + self, + number_of_atom_features: int, + bias: bool = True, ) -> None: - """ Initializes the Residual class. """ + """Initializes the Residual class.""" super(SpookyNetResidual, self).__init__() # initialize attributes self.activation1 = Swish(number_of_atom_features) - self.linear1 = nn.Linear(number_of_atom_features, number_of_atom_features, bias=bias) + self.linear1 = nn.Linear( + number_of_atom_features, number_of_atom_features, bias=bias + ) self.activation2 = Swish(number_of_atom_features) - self.linear2 = nn.Linear(number_of_atom_features, number_of_atom_features, bias=bias) + self.linear2 = nn.Linear( + number_of_atom_features, number_of_atom_features, bias=bias + ) self.reset_parameters(bias) def reset_parameters(self, bias: bool = True) -> None: - """ Initialize parameters to compute an identity mapping. """ + """Initialize parameters to compute an identity mapping.""" nn.init.orthogonal_(self.linear1.weight) nn.init.zeros_(self.linear2.weight) if bias: @@ -479,12 +501,12 @@ class SpookyNetResidualStack(nn.Module): """ def __init__( - self, - number_of_atom_features: int, - number_of_residual_blocks: int, - bias: bool = True, + self, + number_of_atom_features: int, + number_of_residual_blocks: int, + bias: bool = True, ) -> None: - """ Initializes the ResidualStack class. 
""" + """Initializes the ResidualStack class.""" super(SpookyNetResidualStack, self).__init__() self.stack = nn.ModuleList( [ @@ -514,17 +536,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SpookyNetResidualMLP(nn.Module): def __init__( - self, - number_of_atom_features: int, - number_of_residual_blocks: int, - bias: bool = True, + self, + number_of_atom_features: int, + number_of_residual_blocks: int, + bias: bool = True, ) -> None: super(SpookyNetResidualMLP, self).__init__() self.residual = SpookyNetResidualStack( number_of_atom_features, number_of_residual_blocks, bias=bias ) self.activation = Swish(number_of_atom_features) - self.linear = nn.Linear(number_of_atom_features, number_of_atom_features, bias=bias) + self.linear = nn.Linear( + number_of_atom_features, number_of_atom_features, bias=bias + ) self.reset_parameters(bias) def reset_parameters(self, bias: bool = True) -> None: @@ -559,33 +583,41 @@ class SpookyNetLocalInteraction(nn.Module): """ def __init__( - self, - number_of_atom_features: int, - number_of_radial_basis_functions: int, - num_residual_x: int, - num_residual_s: int, - num_residual_p: int, - num_residual_d: int, - num_residual: int, + self, + number_of_atom_features: int, + number_of_radial_basis_functions: int, + num_residual_x: int, + num_residual_s: int, + num_residual_p: int, + num_residual_d: int, + num_residual: int, ) -> None: - """ Initializes the LocalInteraction class. """ + """Initializes the LocalInteraction class.""" super(SpookyNetLocalInteraction, self).__init__() - self.radial_s = nn.Linear(number_of_radial_basis_functions, number_of_atom_features, bias=False) - self.radial_p = nn.Linear(number_of_radial_basis_functions, number_of_atom_features, bias=False) - self.radial_d = nn.Linear(number_of_radial_basis_functions, number_of_atom_features, bias=False) + self.radial_s = nn.Linear( + number_of_radial_basis_functions, number_of_atom_features, bias=False + ) + self.radial_p = nn.Linear( + number_of_radial_basis_functions, number_of_atom_features, bias=False + ) + self.radial_d = nn.Linear( + number_of_radial_basis_functions, number_of_atom_features, bias=False + ) self.resblock_x = SpookyNetResidualMLP(number_of_atom_features, num_residual_x) self.resblock_s = SpookyNetResidualMLP(number_of_atom_features, num_residual_s) self.resblock_p = SpookyNetResidualMLP(number_of_atom_features, num_residual_p) self.resblock_d = SpookyNetResidualMLP(number_of_atom_features, num_residual_d) - self.projection_p = nn.Linear(number_of_atom_features, 2 * number_of_atom_features, bias=False) - self.projection_d = nn.Linear(number_of_atom_features, 2 * number_of_atom_features, bias=False) - self.resblock = SpookyNetResidualMLP( - number_of_atom_features, num_residual + self.projection_p = nn.Linear( + number_of_atom_features, 2 * number_of_atom_features, bias=False + ) + self.projection_d = nn.Linear( + number_of_atom_features, 2 * number_of_atom_features, bias=False ) + self.resblock = SpookyNetResidualMLP(number_of_atom_features, num_residual) self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters. 
""" + """Initialize parameters.""" nn.init.orthogonal_(self.radial_s.weight) nn.init.orthogonal_(self.radial_p.weight) nn.init.orthogonal_(self.radial_d.weight) @@ -593,13 +625,13 @@ def reset_parameters(self) -> None: nn.init.orthogonal_(self.projection_d.weight) def forward( - self, - x_tilde: torch.Tensor, - f_ij_after_cutoff: torch.Tensor, - dir_ij: torch.Tensor, - d_orbital_ij: torch.Tensor, - idx_i: torch.Tensor, - idx_j: torch.Tensor, + self, + x_tilde: torch.Tensor, + f_ij_after_cutoff: torch.Tensor, + dir_ij: torch.Tensor, + d_orbital_ij: torch.Tensor, + idx_i: torch.Tensor, + idx_j: torch.Tensor, ) -> torch.Tensor: """ Evaluate interaction block. @@ -636,8 +668,10 @@ def forward( xd = xd[idx_j] # L=2 # sum over neighbors pp = x_tilde.new_zeros(x_tilde.shape[0], dir_ij.shape[-1], x_tilde.shape[-1]) - dd = x_tilde.new_zeros(x_tilde.shape[0], d_orbital_ij.shape[-1], x_tilde.shape[-1]) - s = xx.index_add(0, idx_i, torch.einsum("pf,pf->pf", gs, xs)) # L=0 + dd = x_tilde.new_zeros( + x_tilde.shape[0], d_orbital_ij.shape[-1], x_tilde.shape[-1] + ) + s = xx.index_add(0, idx_i, torch.einsum("pf,pf->pf", gs, xs)) # L=0 # p: num_pairs, x: 3 (geometry axis), f: number_of_atom_features p = pp.index_add_(0, idx_i, torch.einsum("pxf,pf->pxf", gp, xp)) # L=1 d = dd.index_add_(0, idx_i, torch.einsum("pxf,pf->pxf", gd, xd)) # L=2 @@ -645,7 +679,11 @@ def forward( pa, pb = torch.split(self.projection_p(p), p.shape[-1], dim=-1) da, db = torch.split(self.projection_d(d), d.shape[-1], dim=-1) # n: number_of_atoms_in_system, x: 3 (geometry axis), f: number_of_atom_features - return self.resblock(s + torch.einsum("nxf,nxf->nf", pa, pb) + torch.einsum("nxf,nxf->nf", da, db)) + return self.resblock( + s + + torch.einsum("nxf,nxf->nf", pa, pb) + + torch.einsum("nxf,nxf->nf", da, db) + ) class SpookyNetAttention(nn.Module): @@ -661,10 +699,8 @@ class SpookyNetAttention(nn.Module): this is 0, the exact attention matrix is computed. """ - def __init__( - self, dim_qk: int, num_random_features: int - ) -> None: - """ Initializes the Attention class. """ + def __init__(self, dim_qk: int, num_random_features: int) -> None: + """Initializes the Attention class.""" super(SpookyNetAttention, self).__init__() self.num_random_features = num_random_features omega = self._omega(num_random_features, dim_qk) @@ -672,11 +708,11 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ + """For compatibility with other modules.""" pass def _omega(self, nrows: int, ncols: int) -> np.ndarray: - """ Return a (nrows x ncols) random feature matrix. """ + """Return a (nrows x ncols) random feature matrix.""" nblocks = int(nrows / ncols) blocks = [] for i in range(nblocks): @@ -694,16 +730,16 @@ def _omega(self, nrows: int, ncols: int) -> np.ndarray: return (norm * np.vstack(blocks)).T def _phi( - self, - X: torch.Tensor, - is_query: bool, - eps: float = 1e-4, + self, + X: torch.Tensor, + is_query: bool, + eps: float = 1e-4, ) -> torch.Tensor: - """ Normalize X and project into random feature space. 
""" + """Normalize X and project into random feature space.""" d = X.shape[-1] m = self.omega.shape[-1] - U = torch.matmul(X / d ** 0.25, self.omega) - h = torch.sum(X ** 2, dim=-1, keepdim=True) / (2 * d ** 0.5) # OLD + U = torch.matmul(X / d**0.25, self.omega) + h = torch.sum(X**2, dim=-1, keepdim=True) / (2 * d**0.5) # OLD # determine maximum (is subtracted to prevent numerical overflow) if is_query: maximum, _ = torch.max(U, dim=-1, keepdim=True) @@ -712,11 +748,11 @@ def _phi( return (torch.exp(U - h - maximum) + eps) / math.sqrt(m) def forward( - self, - Q: torch.Tensor, - K: torch.Tensor, - V: torch.Tensor, - eps: float = 1e-8, + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + eps: float = 1e-8, ) -> torch.Tensor: """ Compute attention for the given query, key and value vectors. @@ -760,33 +796,29 @@ class SpookyNetNonlocalInteraction(nn.Module): """ def __init__( - self, - number_of_atom_features: int, - num_residual_q: int, - num_residual_k: int, - num_residual_v: int, + self, + number_of_atom_features: int, + num_residual_q: int, + num_residual_k: int, + num_residual_v: int, ) -> None: - """ Initializes the NonlocalInteraction class. """ + """Initializes the NonlocalInteraction class.""" super(SpookyNetNonlocalInteraction, self).__init__() - self.resblock_q = SpookyNetResidualMLP( - number_of_atom_features, num_residual_q + self.resblock_q = SpookyNetResidualMLP(number_of_atom_features, num_residual_q) + self.resblock_k = SpookyNetResidualMLP(number_of_atom_features, num_residual_k) + self.resblock_v = SpookyNetResidualMLP(number_of_atom_features, num_residual_v) + self.attention = SpookyNetAttention( + dim_qk=number_of_atom_features, num_random_features=number_of_atom_features ) - self.resblock_k = SpookyNetResidualMLP( - number_of_atom_features, num_residual_k - ) - self.resblock_v = SpookyNetResidualMLP( - number_of_atom_features, num_residual_v - ) - self.attention = SpookyNetAttention(dim_qk=number_of_atom_features, num_random_features=number_of_atom_features) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ + """For compatibility with other modules.""" pass def forward( - self, - x_tilde: torch.Tensor, + self, + x_tilde: torch.Tensor, ) -> torch.Tensor: """ Evaluate interaction block. @@ -838,22 +870,22 @@ class SpookyNetInteractionModule(nn.Module): """ def __init__( - self, - number_of_atom_features: int, - number_of_radial_basis_functions: int, - num_residual_pre: int, - num_residual_local_x: int, - num_residual_local_s: int, - num_residual_local_p: int, - num_residual_local_d: int, - num_residual_local: int, - num_residual_nonlocal_q: int, - num_residual_nonlocal_k: int, - num_residual_nonlocal_v: int, - num_residual_post: int, - num_residual_output: int, + self, + number_of_atom_features: int, + number_of_radial_basis_functions: int, + num_residual_pre: int, + num_residual_local_x: int, + num_residual_local_s: int, + num_residual_local_p: int, + num_residual_local_d: int, + num_residual_local: int, + num_residual_nonlocal_q: int, + num_residual_nonlocal_k: int, + num_residual_nonlocal_v: int, + num_residual_post: int, + num_residual_output: int, ) -> None: - """ Initializes the InteractionModule class. 
""" + """Initializes the InteractionModule class.""" super(SpookyNetInteractionModule, self).__init__() # initialize modules self.local_interaction = SpookyNetLocalInteraction( @@ -872,22 +904,28 @@ def __init__( num_residual_v=num_residual_nonlocal_v, ) - self.residual_pre = SpookyNetResidualStack(number_of_atom_features, num_residual_pre) - self.residual_post = SpookyNetResidualStack(number_of_atom_features, num_residual_post) - self.resblock = SpookyNetResidualMLP(number_of_atom_features, num_residual_output) + self.residual_pre = SpookyNetResidualStack( + number_of_atom_features, num_residual_pre + ) + self.residual_post = SpookyNetResidualStack( + number_of_atom_features, num_residual_post + ) + self.resblock = SpookyNetResidualMLP( + number_of_atom_features, num_residual_output + ) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ + """For compatibility with other modules.""" pass def forward( - self, - x: torch.Tensor, - pairlist: torch.Tensor, # shape [n_pairs, 2] - filters: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] TODO: why the 1? - dir_ij: torch.Tensor, # shape [n_pairs, 1] - d_orbital_ij: torch.Tensor, # shape [n_pairs, 1] + self, + x: torch.Tensor, + pairlist: torch.Tensor, # shape [n_pairs, 2] + filters: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] TODO: why the 1? + dir_ij: torch.Tensor, # shape [n_pairs, 1] + d_orbital_ij: torch.Tensor, # shape [n_pairs, 1] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Evaluate all modules in the block. From b7a2b2894bb1fd2dbea5b6f2386ff472249c7aa4 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 11 Jul 2024 19:35:38 -0700 Subject: [PATCH 64/78] Fix ini_alpha with units --- modelforge/potential/spookynet.py | 1 - modelforge/potential/utils.py | 7 +++++-- modelforge/tests/test_spookynet.py | 3 +-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 47295d9c..afa89295 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -319,7 +319,6 @@ def __init__( self.radial_symmetry_function_module = ExponentialBernsteinRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, - ini_alpha=1.0, # TODO: put the right number dtype=torch.float32, ) diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 188593fd..1fa5645c 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -970,7 +970,10 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: class ExponentialBernsteinRadialBasisFunction(RadialBasisFunction): - def __init__(self, number_of_radial_basis_functions, ini_alpha, dtype=torch.int64): + def __init__(self, + number_of_radial_basis_functions: int, + ini_alpha: unit.Quantity = 2.0 * unit.bohr, + dtype=torch.int64): """ ini_alpha (float): Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original @@ -981,7 +984,7 @@ def __init__(self, number_of_radial_basis_functions, ini_alpha, dtype=torch.int6 trainable_prefactor=False, dtype=dtype, ) - self.alpha = ini_alpha #TODO: should this be unitful? 
+ self.register_parameter("alpha", nn.Parameter(torch.tensor(ini_alpha.m_as(unit.nanometer)))) def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: return -( diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 205f7306..1cd5aa72 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -238,8 +238,7 @@ def test_spookynet_bernstein_polynomial_equivalence(): num_basis_functions = 3 ref_exp_bernstein_polynomials = RefExponentialBernsteinPolynomials(num_basis_functions, exp_weighting=True) - mf_exp_bernstein_polynomials = MfExponentialBernSteinPolynomials(num_basis_functions, ini_alpha=1.0) # TODO: put - # the right number + mf_exp_bernstein_polynomials = MfExponentialBernSteinPolynomials(num_basis_functions) N = 5 r_angstrom = torch.rand((N, 1)) From 731bd7b40eab17d3d9ad15f465877655ba3efc79 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Mon, 15 Jul 2024 15:53:58 -0700 Subject: [PATCH 65/78] Trying to implement embeddings --- modelforge/potential/spookynet.py | 99 +++++++++++++++++++++++++++++- modelforge/tests/test_spookynet.py | 25 ++++++++ 2 files changed, 122 insertions(+), 2 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index afa89295..4541848d 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -48,6 +48,12 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): atomic_embedding : torch.Tensor A 2D tensor containing embeddings or features for each atom, derived from atomic numbers. Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors. + charge_embedding : torch.Tensor + A 2D tensor containing embeddings or features for each atom, derived from total charge. + Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors. + magmom_embedding : torch.Tensor + A 2D tensor containing embeddings or features for each atom, derived from spin states. + Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors. f_ij : Optional[torch.Tensor] A tensor representing the radial symmetry function expansion of distances between atom pairs, capturing the local chemical environment. 
Shape: [num_pairs, number_of_atom_features], where `number_of_atom_features` is the dimensionality of @@ -83,6 +89,8 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): """ atomic_embedding: torch.Tensor + charge_embedding: torch.Tensor + magmom_embedding: torch.Tensor f_ij: Optional[torch.Tensor] = field(default=None) f_cutoff: Optional[torch.Tensor] = field(default=None) @@ -121,7 +129,9 @@ def __init__( # embedding from modelforge.potential.utils import Embedding - self.embedding_module = Embedding(max_Z, number_of_atom_features) + self.atomic_embedding_module = Embedding(max_Z, number_of_atom_features) + self.charge_embedding_module = Embedding(max_Z, number_of_atom_features) + self.magmom_embedding_module = Embedding(max_Z, number_of_atom_features) # initialize representation block self.spookynet_representation_module = SpookyNetRepresentation( @@ -202,7 +212,7 @@ def compute_properties( # Compute the representation for each atom (transform to radial basis set, multiply by cutoff) representation = self.spookynet_representation_module(data.d_ij, data.r_ij) - x = data.atomic_embedding + x = data.atomic_embedding + data.charge_embedding + data.magmom_embedding f = x.new_zeros(x.size()) # initialize output features to zero # Iterate over interaction blocks to update features @@ -292,6 +302,91 @@ def _config_prior(self): return prior +class ElectronicEmbedding(nn.Module): + """ + Block for updating atomic features through nonlocal interactions with the + electrons. + + Arguments: + num_features (int): + Dimensions of feature space. + num_basis_functions (int): + Number of radial basis functions. + num_residual_pre_i (int): + Number of residual blocks applied to atomic features in i branch + (central atoms) before computing the interaction. + num_residual_pre_j (int): + Number of residual blocks applied to atomic features in j branch + (neighbouring atoms) before computing the interaction. + num_residual_post (int): + Number of residual blocks applied to interaction features. + activation (str): + Kind of activation function. Possible values: + 'swish': Swish activation function. + 'ssp': Shifted softplus activation function. + """ + + def __init__( + self, + num_features: int, + num_residual: int, + activation: str = "swish", + is_charge: bool = False, + ) -> None: + """ Initializes the ElectronicEmbedding class. """ + super(ElectronicEmbedding, self).__init__() + self.is_charge = is_charge + self.linear_q = nn.Linear(num_features, num_features) + if is_charge: # charges are duplicated to use separate weights for +/- + self.linear_k = nn.Linear(2, num_features, bias=False) + self.linear_v = nn.Linear(2, num_features, bias=False) + else: + self.linear_k = nn.Linear(1, num_features, bias=False) + self.linear_v = nn.Linear(1, num_features, bias=False) + self.resblock = SpookyNetResidualMLP( + num_features, + num_residual, + bias=False, + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + """ Initialize parameters. """ + nn.init.orthogonal_(self.linear_k.weight) + nn.init.orthogonal_(self.linear_v.weight) + nn.init.orthogonal_(self.linear_q.weight) + nn.init.zeros_(self.linear_q.bias) + + def forward( + self, + x: torch.Tensor, + E: torch.Tensor, + num_batch: int, + eps: float = 1e-8, + ) -> torch.Tensor: + """ + Evaluate interaction block. + N: Number of atoms. + + x (FloatTensor [N, num_features]): + Atomic feature vectors. 
+ """ + batch_seg = torch.zeros(x.size(0), dtype=torch.int64, device=x.device) + q = self.linear_q(x) # queries + if self.is_charge: + e = F.relu(torch.stack([E, -E], dim=-1)) + else: + e = torch.abs(E).unsqueeze(-1) # +/- spin is the same => abs + enorm = torch.maximum(e, torch.ones_like(e)) + k = self.linear_k(e / enorm)[batch_seg] # keys + v = self.linear_v(e)[batch_seg] # values + dot = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5 # scaled dot product + a = nn.functional.softplus(dot) # unnormalized attention weights + anorm = a.new_zeros(num_batch).index_add_(0, batch_seg, a) + anorm = anorm[batch_seg] + return self.resblock((a / (anorm + eps)).unsqueeze(-1) * v) + + class SpookyNetRepresentation(nn.Module): def __init__( diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 1cd5aa72..b86e0e1b 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -63,6 +63,31 @@ def test_forward(): ) calculated_results = spookynet.core_module.forward(model_input, pairlist_output) + ref_spookynet = RefSpookyNet( + num_features=config["potential"]["core_parameter"]["number_of_atom_features"], + num_basis_functions=config["potential"]["core_parameter"]["number_of_radial_basis_functions"], + num_modules=config["potential"]["core_parameter"]["number_of_interaction_modules"], + num_residual_electron=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_pre=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_local_x=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_local_s=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_local_p=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_local_d=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_local=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_nonlocal_q=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_nonlocal_k=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_nonlocal_v=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_post=config["potential"]["core_parameter"]["number_of_residual_blocks"], + num_residual_output=config["potential"]["core_parameter"]["number_of_residual_blocks"], + ) + + ref_spookynet( + prepared_input["atomic_numbers"], + prepared_input["total_charge"], + + + ) + def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter): """ From 71c3feaae052eae5e69f2c6a7c12270d4c8932e7 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Mon, 15 Jul 2024 23:03:30 -0700 Subject: [PATCH 66/78] More changes --- modelforge/potential/spookynet.py | 59 ++++++++++-------------------- modelforge/tests/test_spookynet.py | 9 +++-- 2 files changed, 24 insertions(+), 44 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 4541848d..26528c05 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -51,9 +51,6 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): charge_embedding : torch.Tensor A 2D tensor containing embeddings or features for each atom, derived from total charge. Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors. 
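# The ElectronicEmbedding above distributes a molecule's total charge over its
# atoms with a single attention step: per-atom queries come from the nuclear
# features, while keys and values are built from relu([Q, -Q]) so that positive
# and negative charge get separate weights; softplus scores, normalized per
# molecule, then weight the charge's value vector. A condensed single-molecule
# sketch of that mechanism (illustrative; batch segmentation and the trailing
# residual MLP are omitted):
import torch
import torch.nn as nn
import torch.nn.functional as F

def charge_attention(q: torch.Tensor, linear_k: nn.Linear, linear_v: nn.Linear,
                     Q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # q: [N, F] per-atom queries; Q: scalar total molecular charge
    e = F.relu(torch.stack([Q, -Q], dim=-1))        # [2]: +/- charge channels
    enorm = torch.maximum(e, torch.ones_like(e))
    k = linear_k(e / enorm)                         # [F] key encodes charge sign
    v = linear_v(e)                                 # [F] value scales with |Q|
    dot = (q * k).sum(-1) / k.shape[-1] ** 0.5      # [N] scaled dot products
    a = F.softplus(dot)                             # positive attention weights
    return (a / (a.sum() + eps)).unsqueeze(-1) * v  # [N, F]; atom rows sum to ~v

feats = 8
out = charge_attention(torch.randn(3, feats),
                       nn.Linear(2, feats, bias=False),
                       nn.Linear(2, feats, bias=False),
                       torch.tensor(1.0))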
- magmom_embedding : torch.Tensor - A 2D tensor containing embeddings or features for each atom, derived from spin states. - Shape: [num_atoms, embedding_dim], where `embedding_dim` is the dimensionality of the embedding vectors. f_ij : Optional[torch.Tensor] A tensor representing the radial symmetry function expansion of distances between atom pairs, capturing the local chemical environment. Shape: [num_pairs, number_of_atom_features], where `number_of_atom_features` is the dimensionality of @@ -90,7 +87,6 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): atomic_embedding: torch.Tensor charge_embedding: torch.Tensor - magmom_embedding: torch.Tensor f_ij: Optional[torch.Tensor] = field(default=None) f_cutoff: Optional[torch.Tensor] = field(default=None) @@ -130,8 +126,7 @@ def __init__( from modelforge.potential.utils import Embedding self.atomic_embedding_module = Embedding(max_Z, number_of_atom_features) - self.charge_embedding_module = Embedding(max_Z, number_of_atom_features) - self.magmom_embedding_module = Embedding(max_Z, number_of_atom_features) + self.charge_embedding_module = ElectronicEmbedding(number_of_atom_features, number_of_residual_blocks) # initialize representation block self.spookynet_representation_module = SpookyNetRepresentation( @@ -178,6 +173,9 @@ def _model_specific_input_preparation( ) -> SpookyNetNeuralNetworkData: number_of_atoms = data.atomic_numbers.shape[0] + atomic_embedding = self.atomic_embedding_module(data.atomic_numbers) + charge_embedding = self.charge_embedding_module(atomic_embedding, data.total_charge, num_batch=1) # TODO: what is num_batch? + nnp_input = SpookyNetNeuralNetworkData( pair_indices=pairlist_output.pair_indices, d_ij=pairlist_output.d_ij, @@ -187,9 +185,8 @@ def _model_specific_input_preparation( atomic_numbers=data.atomic_numbers, atomic_subsystem_indices=data.atomic_subsystem_indices, total_charge=data.total_charge, - atomic_embedding=self.embedding_module( - data.atomic_numbers - ), # atom embedding + atomic_embedding=atomic_embedding, + charge_embedding=charge_embedding, ) return nnp_input @@ -212,7 +209,7 @@ def compute_properties( # Compute the representation for each atom (transform to radial basis set, multiply by cutoff) representation = self.spookynet_representation_module(data.d_ij, data.r_ij) - x = data.atomic_embedding + data.charge_embedding + data.magmom_embedding + x = data.atomic_embedding + data.charge_embedding f = x.new_zeros(x.size()) # initialize output features to zero # Iterate over interaction blocks to update features @@ -310,39 +307,21 @@ class ElectronicEmbedding(nn.Module): Arguments: num_features (int): Dimensions of feature space. - num_basis_functions (int): - Number of radial basis functions. - num_residual_pre_i (int): - Number of residual blocks applied to atomic features in i branch - (central atoms) before computing the interaction. - num_residual_pre_j (int): - Number of residual blocks applied to atomic features in j branch - (neighbouring atoms) before computing the interaction. - num_residual_post (int): - Number of residual blocks applied to interaction features. - activation (str): - Kind of activation function. Possible values: - 'swish': Swish activation function. - 'ssp': Shifted softplus activation function. + num_residual (int): + TODO """ def __init__( self, num_features: int, num_residual: int, - activation: str = "swish", - is_charge: bool = False, ) -> None: """ Initializes the ElectronicEmbedding class. 
""" - super(ElectronicEmbedding, self).__init__() - self.is_charge = is_charge + super().__init__() self.linear_q = nn.Linear(num_features, num_features) - if is_charge: # charges are duplicated to use separate weights for +/- - self.linear_k = nn.Linear(2, num_features, bias=False) - self.linear_v = nn.Linear(2, num_features, bias=False) - else: - self.linear_k = nn.Linear(1, num_features, bias=False) - self.linear_v = nn.Linear(1, num_features, bias=False) + # charges are duplicated to use separate weights for +/- + self.linear_k = nn.Linear(2, num_features, bias=False, dtype=torch.float32) + self.linear_v = nn.Linear(2, num_features, bias=False) self.resblock = SpookyNetResidualMLP( num_features, num_residual, @@ -373,12 +352,12 @@ def forward( """ batch_seg = torch.zeros(x.size(0), dtype=torch.int64, device=x.device) q = self.linear_q(x) # queries - if self.is_charge: - e = F.relu(torch.stack([E, -E], dim=-1)) - else: - e = torch.abs(E).unsqueeze(-1) # +/- spin is the same => abs - enorm = torch.maximum(e, torch.ones_like(e)) - k = self.linear_k(e / enorm)[batch_seg] # keys + e = F.relu(torch.stack([E, -E], dim=-1)).double() + enorm = torch.maximum(e, torch.ones_like(e, dtype=torch.float64)) + test = e / enorm + print(f"{test.dtype=}") + print(f"{self.linear_k.weight.dtype=}") + k = self.linear_k(test)[batch_seg] # keys v = self.linear_v(e)[batch_seg] # values dot = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5 # scaled dot product a = nn.functional.softplus(dot) # unnormalized attention weights diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index b86e0e1b..1544cc42 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -82,10 +82,11 @@ def test_forward(): ) ref_spookynet( - prepared_input["atomic_numbers"], - prepared_input["total_charge"], - - + prepared_input.atomic_numbers, + prepared_input.total_charge, + prepared_input.positions, + prepared_input.pair_indices[0], + prepared_input.pair_indices[1], ) From e548bf1264c3f959506d4d79544c9023d06994fd Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Fri, 19 Jul 2024 18:31:50 -0700 Subject: [PATCH 67/78] More changes --- modelforge/potential/spookynet.py | 107 +++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 2 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 26528c05..1df27bcf 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -126,6 +126,14 @@ def __init__( from modelforge.potential.utils import Embedding self.atomic_embedding_module = Embedding(max_Z, number_of_atom_features) + self.electronic_embedding_module = nn.Linear(electron_config.shape[1], number_of_atom_features, bias=False) + self.register_buffer("electron_config", torch.tensor(electron_config)) + self.register_parameter( + "atomic_embedding_bias", nn.Parameter(torch.Tensor(max_Z, self.num_features)) + ) + self.atomic_embedding_weight = nn.Linear( + self.electron_config.size(1), self.num_features, bias=False + ) self.charge_embedding_module = ElectronicEmbedding(number_of_atom_features, number_of_residual_blocks) # initialize representation block @@ -174,6 +182,7 @@ def _model_specific_input_preparation( number_of_atoms = data.atomic_numbers.shape[0] atomic_embedding = self.atomic_embedding_module(data.atomic_numbers) + charge_embedding = self.charge_embedding_module(atomic_embedding, data.total_charge, num_batch=1) # TODO: what is num_batch? 
nnp_input = SpookyNetNeuralNetworkData( @@ -355,8 +364,6 @@ def forward( e = F.relu(torch.stack([E, -E], dim=-1)).double() enorm = torch.maximum(e, torch.ones_like(e, dtype=torch.float64)) test = e / enorm - print(f"{test.dtype=}") - print(f"{self.linear_k.weight.dtype=}") k = self.linear_k(test)[batch_seg] # keys v = self.linear_v(e)[batch_seg] # values dot = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5 # scaled dot product @@ -1042,3 +1049,99 @@ def forward( x_updated = self.residual_post(x_tilde + l + n) del x_tilde return x_updated, self.resblock(x_updated) + + +# fmt: off +electron_config = np.array([ + # Z 1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p vs vp vd vf + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # n + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # H + [2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # He + [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Li + [4, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Be + [5, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # B + [6, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # C + [7, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # N + [8, 2, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # O + [9, 2, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # F + [10, 2, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # Ne + [11, 2, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Na + [12, 2, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Mg + [13, 2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # Al + [14, 2, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # Si + [15, 2, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # P + [16, 2, 2, 6, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # S + [17, 2, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # Cl + [18, 2, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # Ar + [19, 2, 2, 6, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # K + [20, 2, 2, 6, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Ca + [21, 2, 2, 6, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Sc + [22, 2, 2, 6, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Ti + [23, 2, 2, 6, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 0], # V + [24, 2, 2, 6, 2, 6, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Cr + [25, 2, 2, 6, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Mn + [26, 2, 2, 6, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 6, 0], # Fe + [27, 2, 2, 6, 2, 6, 2, 7, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 7, 0], # Co + [28, 2, 2, 6, 2, 6, 2, 8, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 8, 0], # Ni + [29, 2, 2, 6, 2, 6, 1, 10, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Cu + [30, 2, 2, 6, 2, 6, 2, 10, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Zn + [31, 2, 2, 6, 2, 6, 2, 10, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 10, 0], # Ga + [32, 2, 2, 6, 2, 6, 2, 10, 2, 0, 0, 0, 0, 0, 0, 0, 2, 2, 10, 0], # Ge + [33, 2, 2, 6, 2, 6, 2, 10, 3, 0, 0, 0, 0, 0, 0, 0, 2, 3, 10, 0], # As + [34, 2, 2, 6, 2, 6, 2, 10, 4, 0, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0], # Se + [35, 2, 2, 6, 2, 6, 2, 10, 5, 0, 0, 0, 0, 0, 0, 0, 2, 5, 10, 0], # Br + [36, 2, 2, 6, 2, 6, 2, 10, 6, 0, 0, 0, 0, 0, 0, 0, 2, 6, 10, 0], # Kr + [37, 2, 2, 6, 2, 6, 2, 10, 6, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Rb + [38, 2, 2, 6, 2, 6, 2, 10, 6, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Sr + [39, 2, 2, 6, 2, 6, 2, 10, 6, 2, 1, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Y + [40, 2, 2, 6, 2, 6, 2, 10, 6, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Zr + [41, 
2, 2, 6, 2, 6, 2, 10, 6, 1, 4, 0, 0, 0, 0, 0, 1, 0, 4, 0], # Nb + [42, 2, 2, 6, 2, 6, 2, 10, 6, 1, 5, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Mo + [43, 2, 2, 6, 2, 6, 2, 10, 6, 2, 5, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Tc + [44, 2, 2, 6, 2, 6, 2, 10, 6, 1, 7, 0, 0, 0, 0, 0, 1, 0, 7, 0], # Ru + [45, 2, 2, 6, 2, 6, 2, 10, 6, 1, 8, 0, 0, 0, 0, 0, 1, 0, 8, 0], # Rh + [46, 2, 2, 6, 2, 6, 2, 10, 6, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0], # Pd + [47, 2, 2, 6, 2, 6, 2, 10, 6, 1, 10, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Ag + [48, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Cd + [49, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 1, 0, 0, 0, 0, 2, 1, 10, 0], # In + [50, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 2, 0, 0, 0, 0, 2, 2, 10, 0], # Sn + [51, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 3, 0, 0, 0, 0, 2, 3, 10, 0], # Sb + [52, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 4, 0, 0, 0, 0, 2, 4, 10, 0], # Te + [53, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 5, 0, 0, 0, 0, 2, 5, 10, 0], # I + [54, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 0, 0, 0, 0, 2, 6, 10, 0], # Xe + [55, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 0, 0, 0, 1, 0, 0, 0], # Cs + [56, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 0, 0, 2, 0, 0, 0], # Ba + [57, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 1, 0, 2, 0, 1, 0], # La + [58, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 1, 1, 0, 2, 0, 1, 1], # Ce + [59, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 3, 0, 0, 2, 0, 0, 3], # Pr + [60, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 4, 0, 0, 2, 0, 0, 4], # Nd + [61, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 5, 0, 0, 2, 0, 0, 5], # Pm + [62, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 6, 0, 0, 2, 0, 0, 6], # Sm + [63, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 0, 0, 2, 0, 0, 7], # Eu + [64, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 1, 0, 2, 0, 1, 7], # Gd + [65, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 9, 0, 0, 2, 0, 0, 9], # Tb + [66, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 10, 0, 0, 2, 0, 0, 10], # Dy + [67, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 11, 0, 0, 2, 0, 0, 11], # Ho + [68, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 12, 0, 0, 2, 0, 0, 12], # Er + [69, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 13, 0, 0, 2, 0, 0, 13], # Tm + [70, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 0, 0, 2, 0, 0, 14], # Yb + [71, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 1, 0, 2, 0, 1, 14], # Lu + [72, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 2, 0, 2, 0, 2, 14], # Hf + [73, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 3, 0, 2, 0, 3, 14], # Ta + [74, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 4, 0, 2, 0, 4, 14], # W + [75, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 5, 0, 2, 0, 5, 14], # Re + [76, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 6, 0, 2, 0, 6, 14], # Os + [77, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 7, 0, 2, 0, 7, 14], # Ir + [78, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 9, 0, 1, 0, 9, 14], # Pt + [79, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 10, 0, 1, 0, 10, 14], # Au + [80, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 0, 2, 0, 10, 14], # Hg + [81, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 1, 2, 1, 10, 14], # Tl + [82, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 2, 2, 2, 10, 14], # Pb + [83, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 3, 2, 3, 10, 14], # Bi + [84, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 4, 2, 4, 10, 14], # Po + [85, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 5, 2, 5, 10, 14], # At + [86, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 6, 10, 14] # Rn +], dtype=np.float64) +# fmt: on +# normalize entries (between 0.0 and 1.0) +electron_config = electron_config / np.max(electron_config, axis=0) From 5ea8988874b2ee63eaca18701b9167d8de0859c7 
Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 23 Jul 2024 09:32:48 -0700 Subject: [PATCH 68/78] Fix atomic embedding --- modelforge/potential/spookynet.py | 218 +++++++++--------- .../data/potential_defaults/spookynet.toml | 2 +- 2 files changed, 114 insertions(+), 106 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 1df27bcf..b83d6421 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -94,7 +94,7 @@ class SpookyNetNeuralNetworkData(NeuralNetworkData): class SpookyNetCore(CoreNetwork): def __init__( self, - max_Z: int = 100, + max_Z: int = 87, # need to update electron_config if we want to use higher atomic numbers number_of_atom_features: int = 64, number_of_radial_basis_functions: int = 20, number_of_interaction_modules: int = 3, @@ -106,7 +106,7 @@ def __init__( Parameters ---------- - max_Z : int, default=100 + max_Z : int, default=87 Maximum atomic number to be embedded. number_of_atom_features : int, default=64 Dimension of the embedding vectors for atomic numbers. @@ -125,15 +125,8 @@ def __init__( # embedding from modelforge.potential.utils import Embedding - self.atomic_embedding_module = Embedding(max_Z, number_of_atom_features) - self.electronic_embedding_module = nn.Linear(electron_config.shape[1], number_of_atom_features, bias=False) - self.register_buffer("electron_config", torch.tensor(electron_config)) - self.register_parameter( - "atomic_embedding_bias", nn.Parameter(torch.Tensor(max_Z, self.num_features)) - ) - self.atomic_embedding_weight = nn.Linear( - self.electron_config.size(1), self.num_features, bias=False - ) + assert max_Z <= 87 + self.atomic_embedding_module = SpookyNetAtomicEmbedding(number_of_atom_features, max_Z) self.charge_embedding_module = ElectronicEmbedding(number_of_atom_features, number_of_residual_blocks) # initialize representation block @@ -244,6 +237,115 @@ def compute_properties( from .models import InputPreparation, NNPInput, BaseNetwork +class SpookyNetAtomicEmbedding(nn.Module): + + def __init__(self, number_of_atom_features, max_Z): + super().__init__() + + # fmt: off + electron_config = np.array([ + # Z 1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p vs vp vd vf + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # n + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # H + [2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # He + [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Li + [4, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Be + [5, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # B + [6, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # C + [7, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # N + [8, 2, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # O + [9, 2, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # F + [10, 2, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # Ne + [11, 2, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Na + [12, 2, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Mg + [13, 2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # Al + [14, 2, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # Si + [15, 2, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # P + [16, 2, 2, 6, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # S + [17, 2, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # Cl + [18, 2, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # 
Ar + [19, 2, 2, 6, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # K + [20, 2, 2, 6, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Ca + [21, 2, 2, 6, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Sc + [22, 2, 2, 6, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Ti + [23, 2, 2, 6, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 0], # V + [24, 2, 2, 6, 2, 6, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Cr + [25, 2, 2, 6, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Mn + [26, 2, 2, 6, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 6, 0], # Fe + [27, 2, 2, 6, 2, 6, 2, 7, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 7, 0], # Co + [28, 2, 2, 6, 2, 6, 2, 8, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 8, 0], # Ni + [29, 2, 2, 6, 2, 6, 1, 10, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Cu + [30, 2, 2, 6, 2, 6, 2, 10, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Zn + [31, 2, 2, 6, 2, 6, 2, 10, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 10, 0], # Ga + [32, 2, 2, 6, 2, 6, 2, 10, 2, 0, 0, 0, 0, 0, 0, 0, 2, 2, 10, 0], # Ge + [33, 2, 2, 6, 2, 6, 2, 10, 3, 0, 0, 0, 0, 0, 0, 0, 2, 3, 10, 0], # As + [34, 2, 2, 6, 2, 6, 2, 10, 4, 0, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0], # Se + [35, 2, 2, 6, 2, 6, 2, 10, 5, 0, 0, 0, 0, 0, 0, 0, 2, 5, 10, 0], # Br + [36, 2, 2, 6, 2, 6, 2, 10, 6, 0, 0, 0, 0, 0, 0, 0, 2, 6, 10, 0], # Kr + [37, 2, 2, 6, 2, 6, 2, 10, 6, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Rb + [38, 2, 2, 6, 2, 6, 2, 10, 6, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Sr + [39, 2, 2, 6, 2, 6, 2, 10, 6, 2, 1, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Y + [40, 2, 2, 6, 2, 6, 2, 10, 6, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Zr + [41, 2, 2, 6, 2, 6, 2, 10, 6, 1, 4, 0, 0, 0, 0, 0, 1, 0, 4, 0], # Nb + [42, 2, 2, 6, 2, 6, 2, 10, 6, 1, 5, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Mo + [43, 2, 2, 6, 2, 6, 2, 10, 6, 2, 5, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Tc + [44, 2, 2, 6, 2, 6, 2, 10, 6, 1, 7, 0, 0, 0, 0, 0, 1, 0, 7, 0], # Ru + [45, 2, 2, 6, 2, 6, 2, 10, 6, 1, 8, 0, 0, 0, 0, 0, 1, 0, 8, 0], # Rh + [46, 2, 2, 6, 2, 6, 2, 10, 6, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0], # Pd + [47, 2, 2, 6, 2, 6, 2, 10, 6, 1, 10, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Ag + [48, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Cd + [49, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 1, 0, 0, 0, 0, 2, 1, 10, 0], # In + [50, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 2, 0, 0, 0, 0, 2, 2, 10, 0], # Sn + [51, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 3, 0, 0, 0, 0, 2, 3, 10, 0], # Sb + [52, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 4, 0, 0, 0, 0, 2, 4, 10, 0], # Te + [53, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 5, 0, 0, 0, 0, 2, 5, 10, 0], # I + [54, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 0, 0, 0, 0, 2, 6, 10, 0], # Xe + [55, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 0, 0, 0, 1, 0, 0, 0], # Cs + [56, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 0, 0, 2, 0, 0, 0], # Ba + [57, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 1, 0, 2, 0, 1, 0], # La + [58, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 1, 1, 0, 2, 0, 1, 1], # Ce + [59, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 3, 0, 0, 2, 0, 0, 3], # Pr + [60, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 4, 0, 0, 2, 0, 0, 4], # Nd + [61, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 5, 0, 0, 2, 0, 0, 5], # Pm + [62, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 6, 0, 0, 2, 0, 0, 6], # Sm + [63, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 0, 0, 2, 0, 0, 7], # Eu + [64, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 1, 0, 2, 0, 1, 7], # Gd + [65, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 9, 0, 0, 2, 0, 0, 9], # Tb + [66, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 10, 0, 0, 2, 0, 0, 10], # Dy + [67, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 11, 0, 0, 2, 0, 0, 11], # Ho + [68, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 
2, 12, 0, 0, 2, 0, 0, 12], # Er + [69, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 13, 0, 0, 2, 0, 0, 13], # Tm + [70, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 0, 0, 2, 0, 0, 14], # Yb + [71, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 1, 0, 2, 0, 1, 14], # Lu + [72, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 2, 0, 2, 0, 2, 14], # Hf + [73, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 3, 0, 2, 0, 3, 14], # Ta + [74, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 4, 0, 2, 0, 4, 14], # W + [75, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 5, 0, 2, 0, 5, 14], # Re + [76, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 6, 0, 2, 0, 6, 14], # Os + [77, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 7, 0, 2, 0, 7, 14], # Ir + [78, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 9, 0, 1, 0, 9, 14], # Pt + [79, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 10, 0, 1, 0, 10, 14], # Au + [80, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 0, 2, 0, 10, 14], # Hg + [81, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 1, 2, 1, 10, 14], # Tl + [82, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 2, 2, 2, 10, 14], # Pb + [83, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 3, 2, 3, 10, 14], # Bi + [84, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 4, 2, 4, 10, 14], # Po + [85, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 5, 2, 5, 10, 14], # At + [86, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 6, 10, 14] # Rn + ], dtype=np.float64) + # fmt: on + # normalize entries (between 0.0 and 1.0) + self.register_buffer("electron_config", torch.tensor(electron_config / np.max(electron_config, axis=0))) + self.register_parameter("atomic_number_weights", + nn.Parameter(torch.zeros((number_of_atom_features, self.electron_config.shape[1])))) + self.atomic_bias = nn.Embedding(max_Z, number_of_atom_features) + + def forward(self, atomic_numbers): + return torch.einsum("fe,ne->nf", self.atomic_number_weights, + self.electron_config[atomic_numbers]) + self.atomic_bias(atomic_numbers) + + + class SpookyNet(BaseNetwork): def __init__( self, @@ -1051,97 +1153,3 @@ def forward( return x_updated, self.resblock(x_updated) -# fmt: off -electron_config = np.array([ - # Z 1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p vs vp vd vf - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # n - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # H - [2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # He - [3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Li - [4, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Be - [5, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # B - [6, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # C - [7, 2, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # N - [8, 2, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # O - [9, 2, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # F - [10, 2, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # Ne - [11, 2, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Na - [12, 2, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Mg - [13, 2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0], # Al - [14, 2, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0], # Si - [15, 2, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0], # P - [16, 2, 2, 6, 2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0], # S - [17, 2, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 5, 0, 0], # Cl - [18, 2, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0], # Ar - [19, 2, 2, 6, 2, 6, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 
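With the configuration table in place, the embedding itself is compact: every column of electron_config is divided by its column maximum so all features lie in [0, 1], and the forward pass is a linear projection of each atom's configuration row plus a learned per-element bias. The einsum "fe,ne->nf" is an ordinary matrix product applied row-wise; a toy-shape sketch (all names local to the example):

    import torch
    import torch.nn as nn

    num_features, config_dim, max_Z = 4, 20, 87
    config_linear = torch.zeros(num_features, config_dim)  # learned projection, zero-initialized
    electron_config = torch.rand(max_Z, config_dim)        # stand-in for the normalized table
    element_bias = nn.Embedding(max_Z, num_features)       # learned per-element bias
    Z = torch.tensor([1, 6, 8])                            # H, C, O
    emb = torch.einsum("fe,ne->nf", config_linear, electron_config[Z]) + element_bias(Z)
    print(emb.shape)                                       # torch.Size([3, 4])
    # the einsum is equivalent to a plain matmul:
    assert torch.allclose(emb, electron_config[Z] @ config_linear.T + element_bias(Z))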
0, 0], # K - [20, 2, 2, 6, 2, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Ca - [21, 2, 2, 6, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Sc - [22, 2, 2, 6, 2, 6, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Ti - [23, 2, 2, 6, 2, 6, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 0], # V - [24, 2, 2, 6, 2, 6, 1, 5, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Cr - [25, 2, 2, 6, 2, 6, 2, 5, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Mn - [26, 2, 2, 6, 2, 6, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 6, 0], # Fe - [27, 2, 2, 6, 2, 6, 2, 7, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 7, 0], # Co - [28, 2, 2, 6, 2, 6, 2, 8, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 8, 0], # Ni - [29, 2, 2, 6, 2, 6, 1, 10, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Cu - [30, 2, 2, 6, 2, 6, 2, 10, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Zn - [31, 2, 2, 6, 2, 6, 2, 10, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 10, 0], # Ga - [32, 2, 2, 6, 2, 6, 2, 10, 2, 0, 0, 0, 0, 0, 0, 0, 2, 2, 10, 0], # Ge - [33, 2, 2, 6, 2, 6, 2, 10, 3, 0, 0, 0, 0, 0, 0, 0, 2, 3, 10, 0], # As - [34, 2, 2, 6, 2, 6, 2, 10, 4, 0, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0], # Se - [35, 2, 2, 6, 2, 6, 2, 10, 5, 0, 0, 0, 0, 0, 0, 0, 2, 5, 10, 0], # Br - [36, 2, 2, 6, 2, 6, 2, 10, 6, 0, 0, 0, 0, 0, 0, 0, 2, 6, 10, 0], # Kr - [37, 2, 2, 6, 2, 6, 2, 10, 6, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], # Rb - [38, 2, 2, 6, 2, 6, 2, 10, 6, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], # Sr - [39, 2, 2, 6, 2, 6, 2, 10, 6, 2, 1, 0, 0, 0, 0, 0, 2, 0, 1, 0], # Y - [40, 2, 2, 6, 2, 6, 2, 10, 6, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0], # Zr - [41, 2, 2, 6, 2, 6, 2, 10, 6, 1, 4, 0, 0, 0, 0, 0, 1, 0, 4, 0], # Nb - [42, 2, 2, 6, 2, 6, 2, 10, 6, 1, 5, 0, 0, 0, 0, 0, 1, 0, 5, 0], # Mo - [43, 2, 2, 6, 2, 6, 2, 10, 6, 2, 5, 0, 0, 0, 0, 0, 2, 0, 5, 0], # Tc - [44, 2, 2, 6, 2, 6, 2, 10, 6, 1, 7, 0, 0, 0, 0, 0, 1, 0, 7, 0], # Ru - [45, 2, 2, 6, 2, 6, 2, 10, 6, 1, 8, 0, 0, 0, 0, 0, 1, 0, 8, 0], # Rh - [46, 2, 2, 6, 2, 6, 2, 10, 6, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0], # Pd - [47, 2, 2, 6, 2, 6, 2, 10, 6, 1, 10, 0, 0, 0, 0, 0, 1, 0, 10, 0], # Ag - [48, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 0, 0, 0, 0, 0, 2, 0, 10, 0], # Cd - [49, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 1, 0, 0, 0, 0, 2, 1, 10, 0], # In - [50, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 2, 0, 0, 0, 0, 2, 2, 10, 0], # Sn - [51, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 3, 0, 0, 0, 0, 2, 3, 10, 0], # Sb - [52, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 4, 0, 0, 0, 0, 2, 4, 10, 0], # Te - [53, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 5, 0, 0, 0, 0, 2, 5, 10, 0], # I - [54, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 0, 0, 0, 0, 2, 6, 10, 0], # Xe - [55, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 0, 0, 0, 1, 0, 0, 0], # Cs - [56, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 0, 0, 2, 0, 0, 0], # Ba - [57, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 0, 1, 0, 2, 0, 1, 0], # La - [58, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 1, 1, 0, 2, 0, 1, 1], # Ce - [59, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 3, 0, 0, 2, 0, 0, 3], # Pr - [60, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 4, 0, 0, 2, 0, 0, 4], # Nd - [61, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 5, 0, 0, 2, 0, 0, 5], # Pm - [62, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 6, 0, 0, 2, 0, 0, 6], # Sm - [63, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 0, 0, 2, 0, 0, 7], # Eu - [64, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 7, 1, 0, 2, 0, 1, 7], # Gd - [65, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 9, 0, 0, 2, 0, 0, 9], # Tb - [66, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 10, 0, 0, 2, 0, 0, 10], # Dy - [67, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 11, 0, 0, 2, 0, 0, 11], # Ho - [68, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 12, 0, 0, 2, 0, 0, 12], # Er - [69, 2, 2, 6, 2, 6, 2, 10, 
6, 2, 10, 6, 2, 13, 0, 0, 2, 0, 0, 13], # Tm - [70, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 0, 0, 2, 0, 0, 14], # Yb - [71, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 1, 0, 2, 0, 1, 14], # Lu - [72, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 2, 0, 2, 0, 2, 14], # Hf - [73, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 3, 0, 2, 0, 3, 14], # Ta - [74, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 4, 0, 2, 0, 4, 14], # W - [75, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 5, 0, 2, 0, 5, 14], # Re - [76, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 6, 0, 2, 0, 6, 14], # Os - [77, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 7, 0, 2, 0, 7, 14], # Ir - [78, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 9, 0, 1, 0, 9, 14], # Pt - [79, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 1, 14, 10, 0, 1, 0, 10, 14], # Au - [80, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 0, 2, 0, 10, 14], # Hg - [81, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 1, 2, 1, 10, 14], # Tl - [82, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 2, 2, 2, 10, 14], # Pb - [83, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 3, 2, 3, 10, 14], # Bi - [84, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 4, 2, 4, 10, 14], # Po - [85, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 5, 2, 5, 10, 14], # At - [86, 2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 6, 10, 14] # Rn -], dtype=np.float64) -# fmt: on -# normalize entries (between 0.0 and 1.0) -electron_config = electron_config / np.max(electron_config, axis=0) diff --git a/modelforge/tests/data/potential_defaults/spookynet.toml b/modelforge/tests/data/potential_defaults/spookynet.toml index a9071aa4..0c7ba898 100644 --- a/modelforge/tests/data/potential_defaults/spookynet.toml +++ b/modelforge/tests/data/potential_defaults/spookynet.toml @@ -2,7 +2,7 @@ model_name = "SpookyNet" [potential.core_parameter] -max_Z = 101 +max_Z = 87 number_of_atom_features = 32 number_of_radial_basis_functions = 20 cutoff = "5.0 angstrom" From 5a60c2c1c2385d00bc8232757301c9def35503be Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Wed, 24 Jul 2024 18:09:07 -0700 Subject: [PATCH 69/78] Fix merge conflict issues --- modelforge/potential/spookynet.py | 9 +-- modelforge/potential/utils.py | 91 ++++++++++++++++++++++++++++++ modelforge/tests/test_schnet.py | 2 + modelforge/tests/test_spookynet.py | 4 +- 4 files changed, 101 insertions(+), 5 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index b83d6421..742ff54e 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -375,9 +375,14 @@ def __init__( cutoff : openff.units.unit.Quantity The cutoff distance for interactions. 
""" + from modelforge.utils.units import _convert + + self.only_unique_pairs = False # NOTE: need to be set before super().__init__ + super().__init__( dataset_statistic=dataset_statistic, postprocessing_parameter=postprocessing_parameter, + cutoff=_convert(cutoff) ) from modelforge.utils.units import _convert @@ -388,10 +393,6 @@ def __init__( number_of_interaction_modules=number_of_interaction_modules, number_of_residual_blocks=number_of_residual_blocks, ) - self.only_unique_pairs = False # NOTE: for pairlist - self.input_preparation = InputPreparation( - cutoff=_convert(cutoff), only_unique_pairs=self.only_unique_pairs - ) def _config_prior(self): log.info("Configuring SpookyNet model hyperparameter prior distribution") diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 37c7d285..b40c5044 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -513,6 +513,70 @@ def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: return torch.exp(-(nondimensionalized_distances**2)) +class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore): + """ + Taken from SpookyNet. + Radial basis functions based on exponential Bernstein polynomials given by: + b_{v,n}(x) = (n over v) * exp(-alpha*x)**v * (1-exp(-alpha*x))**(n-v) + (see https://en.wikipedia.org/wiki/Bernstein_polynomial) + Here, n = num_basis_functions-1 and v takes values from 0 to n. This + implementation operates in log space to prevent multiplication of very large + (n over v) and very small numbers (exp(-alpha*x)**v and + (1-exp(-alpha*x))**(n-v)) for numerical stability. + NOTE: There is a problem for x = 0, as log(-expm1(0)) will be log(0) = -inf. + This itself is not an issue, but the buffer v contains an entry 0 and + 0*(-inf)=nan. The correct behaviour could be recovered by replacing the nan + with 0.0, but should not be necessary because issues are only present when + r = 0, which will not occur with chemically meaningful inputs. + + Arguments: + number_of_radial_basis_functions (int): + Number of radial basis functions. + x = infinity. + """ + + def __init__(self, number_of_radial_basis_functions: int): + super().__init__(number_of_radial_basis_functions) + logfactorial = np.zeros(number_of_radial_basis_functions) + for i in range(2, number_of_radial_basis_functions): + logfactorial[i] = logfactorial[i - 1] + np.log(i) + v = np.arange(0, number_of_radial_basis_functions) + n = (number_of_radial_basis_functions - 1) - v + logbinomial = logfactorial[-1] - logfactorial[v] - logfactorial[n] + # register buffers and parameters + dtype = torch.float64 # TODO: make this a parameter + self.logc = torch.tensor(logbinomial, dtype=dtype) + self.n = torch.tensor(n, dtype=dtype) + self.v = torch.tensor(v, dtype=dtype) + + def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor: + """ + Evaluates radial basis functions given distances + N: Number of input values. + num_basis_functions: Number of radial basis functions. + + Arguments: + nondimensionalized_distances (FloatTensor [N]): + Input distances. + + Returns: + rbf (FloatTensor [N, num_basis_functions]): + Values of the radial basis functions for the distances r. 
+ """ + assert nondimensionalized_distances.ndim == 2 + assert ( + nondimensionalized_distances.shape[1] + == self.number_of_radial_basis_functions + ) + x = ( + self.logc + + (self.n + 1) * nondimensionalized_distances + + self.v * torch.log(-torch.expm1(nondimensionalized_distances)) + ) + + return torch.exp(x) + + class RadialBasisFunction(nn.Module, ABC): def __init__( @@ -926,6 +990,33 @@ def nondimensionalize_distances(self, distances: torch.Tensor) -> torch.Tensor: ) / self.radial_scale_factor +class ExponentialBernsteinRadialBasisFunction(RadialBasisFunction): + + def __init__(self, + number_of_radial_basis_functions: int, + ini_alpha: unit.Quantity = 2.0 * unit.bohr, + dtype=torch.int64): + """ + ini_alpha (float): + Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original + default is 0.5/bohr, so we use 2 bohr). + """ + super().__init__( + ExponentialBernsteinPolynomialsCore(number_of_radial_basis_functions), + trainable_prefactor=False, + dtype=dtype, + ) + self.register_parameter("alpha", nn.Parameter(torch.tensor(ini_alpha.m_as(unit.nanometer)))) + + def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor: + return -( + d_ij.broadcast_to( + (len(d_ij), self.radial_basis_function.number_of_radial_basis_functions) + ) + / self.alpha + ) + + def pair_list( atomic_subsystem_indices: torch.Tensor, only_unique_pairs: bool = False, diff --git a/modelforge/tests/test_schnet.py b/modelforge/tests/test_schnet.py index ed8b9c06..ded1c8f2 100644 --- a/modelforge/tests/test_schnet.py +++ b/modelforge/tests/test_schnet.py @@ -173,6 +173,8 @@ def test_compare_forward(): config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 5 config["potential"]["core_parameter"]["number_of_filters"] = 12 + print(f"{config['potential']['core_parameter']}=") + torch.manual_seed(1234) # initialize model diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 1544cc42..93d7d4f5 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -42,6 +42,8 @@ def test_forward(): config["potential"]["core_parameter"]["number_of_atom_features"] = 12 config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 7 + print(f"{config['potential']['core_parameter']}=") + torch.manual_seed(1234) # initialize model @@ -79,7 +81,7 @@ def test_forward(): num_residual_nonlocal_v=config["potential"]["core_parameter"]["number_of_residual_blocks"], num_residual_post=config["potential"]["core_parameter"]["number_of_residual_blocks"], num_residual_output=config["potential"]["core_parameter"]["number_of_residual_blocks"], - ) + ).double() ref_spookynet( prepared_input.atomic_numbers, From e4dd91c465e6d0e1df8f2a789b5d613bc7bc98fc Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Wed, 24 Jul 2024 18:38:15 -0700 Subject: [PATCH 70/78] Remove unnecessarily hard-coded arrays in the model implementation and cast everything to double in the test --- modelforge/potential/spookynet.py | 9 ++++----- modelforge/tests/test_spookynet.py | 22 ++++++++++------------ 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 742ff54e..b2cc0c5e 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -432,7 +432,7 @@ def __init__( super().__init__() self.linear_q = nn.Linear(num_features, num_features) # charges are duplicated to use separate weights for +/- - self.linear_k 
= nn.Linear(2, num_features, bias=False, dtype=torch.float32) + self.linear_k = nn.Linear(2, num_features, bias=False) self.linear_v = nn.Linear(2, num_features, bias=False) self.resblock = SpookyNetResidualMLP( num_features, @@ -464,8 +464,8 @@ def forward( """ batch_seg = torch.zeros(x.size(0), dtype=torch.int64, device=x.device) q = self.linear_q(x) # queries - e = F.relu(torch.stack([E, -E], dim=-1)).double() - enorm = torch.maximum(e, torch.ones_like(e, dtype=torch.float64)) + e = F.relu(torch.stack([E, -E], dim=-1)) + enorm = torch.maximum(e, torch.ones_like(e)) test = e / enorm k = self.linear_k(test)[batch_seg] # keys v = self.linear_v(e)[batch_seg] # values @@ -503,7 +503,6 @@ def __init__( self.radial_symmetry_function_module = ExponentialBernsteinRadialBasisFunction( number_of_radial_basis_functions=number_of_radial_basis_functions, - dtype=torch.float32, ) self.cutoff_module = CosineCutoff(cutoff=cutoff) @@ -887,7 +886,7 @@ def __init__(self, dim_qk: int, num_random_features: int) -> None: super(SpookyNetAttention, self).__init__() self.num_random_features = num_random_features omega = self._omega(num_random_features, dim_qk) - self.register_buffer("omega", torch.tensor(omega, dtype=torch.float32)) + self.register_buffer("omega", torch.tensor(omega)) self.reset_parameters() def reset_parameters(self) -> None: diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 93d7d4f5..a13cf642 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -4,7 +4,7 @@ setup_single_methane_input, ) import torch - +from icecream import ic import pytest @@ -54,15 +54,12 @@ def test_forward(): input = setup_single_methane_input() model_input = input["modelforge_methane_input"] - + model_input.positions = model_input.positions.double() + model_input.total_charge = model_input.total_charge.double() spookynet.input_preparation._input_checks(model_input) pairlist_output = spookynet.input_preparation.prepare_inputs(model_input) - print(f"{pairlist_output.d_ij.shape=}") - prepared_input = spookynet.core_module._model_specific_input_preparation( - model_input, pairlist_output - ) calculated_results = spookynet.core_module.forward(model_input, pairlist_output) ref_spookynet = RefSpookyNet( @@ -84,11 +81,11 @@ def test_forward(): ).double() ref_spookynet( - prepared_input.atomic_numbers, - prepared_input.total_charge, - prepared_input.positions, - prepared_input.pair_indices[0], - prepared_input.pair_indices[1], + model_input.atomic_numbers, + model_input.total_charge, + model_input.positions, + pairlist_output.pair_indices[0], + pairlist_output.pair_indices[1], ) @@ -261,7 +258,8 @@ def test_spookynet_interaction_module_against_reference(): def test_spookynet_bernstein_polynomial_equivalence(): - from spookynet.modules.exponential_bernstein_polynomials import ExponentialBernsteinPolynomials as RefExponentialBernsteinPolynomials + from spookynet.modules.exponential_bernstein_polynomials import \ + ExponentialBernsteinPolynomials as RefExponentialBernsteinPolynomials from modelforge.potential.utils import ExponentialBernsteinRadialBasisFunction as MfExponentialBernSteinPolynomials num_basis_functions = 3 From 5e00a6d968c0a070415a36030c64346d7ee96dd2 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 25 Jul 2024 09:49:04 -0700 Subject: [PATCH 71/78] Changes to tests and model --- modelforge/potential/spookynet.py | 89 +++++++++++++++++------------- modelforge/tests/test_spookynet.py | 10 ++++ 2 files changed, 61 insertions(+), 38 
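The dtype cleanup in this patch is the right call: pinning dtype=torch.float32 on one linear layer while casting activations with .double() elsewhere guarantees mismatched-dtype failures. With no hard-coded dtypes, the whole module can be flipped with a single cast, which is exactly what the updated test relies on via .double(). A minimal sketch:

    import torch
    import torch.nn as nn

    layer = nn.Linear(2, 4, bias=False)                 # created in the default dtype (float32)
    layer.double()                                      # one call converts every weight to float64
    out = layer(torch.ones(3, 2, dtype=torch.float64))  # inputs match, no hidden casts needed
    print(out.dtype)                                    # torch.float64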
deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index b2cc0c5e..9a1567b8 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -16,6 +16,7 @@ from modelforge.potential.utils import NNPInput from modelforge.potential.utils import NeuralNetworkData +from icecream import ic @dataclass @@ -126,8 +127,12 @@ def __init__( from modelforge.potential.utils import Embedding assert max_Z <= 87 - self.atomic_embedding_module = SpookyNetAtomicEmbedding(number_of_atom_features, max_Z) - self.charge_embedding_module = ElectronicEmbedding(number_of_atom_features, number_of_residual_blocks) + self.atomic_embedding_module = SpookyNetAtomicEmbedding( + number_of_atom_features, max_Z + ) + self.charge_embedding_module = ElectronicEmbedding( + number_of_atom_features, number_of_residual_blocks + ) # initialize representation block self.spookynet_representation_module = SpookyNetRepresentation( @@ -176,7 +181,9 @@ def _model_specific_input_preparation( atomic_embedding = self.atomic_embedding_module(data.atomic_numbers) - charge_embedding = self.charge_embedding_module(atomic_embedding, data.total_charge, num_batch=1) # TODO: what is num_batch? + charge_embedding = self.charge_embedding_module( + atomic_embedding, data.total_charge + ) nnp_input = SpookyNetNeuralNetworkData( pair_indices=pairlist_output.pair_indices, @@ -335,15 +342,30 @@ def __init__(self, number_of_atom_features, max_Z): ], dtype=np.float64) # fmt: on # normalize entries (between 0.0 and 1.0) - self.register_buffer("electron_config", torch.tensor(electron_config / np.max(electron_config, axis=0))) - self.register_parameter("atomic_number_weights", - nn.Parameter(torch.zeros((number_of_atom_features, self.electron_config.shape[1])))) - self.atomic_bias = nn.Embedding(max_Z, number_of_atom_features) + self.register_buffer( + "electron_config", + torch.tensor(electron_config / np.max(electron_config, axis=0)), + ) + self.element_embedding = nn.Embedding(max_Z, number_of_atom_features) + self.register_parameter( + "config_linear", + nn.Parameter( + torch.zeros((number_of_atom_features, self.electron_config.shape[1])) + ), + ) + self.reset_parameters() - def forward(self, atomic_numbers): - return torch.einsum("fe,ne->nf", self.atomic_number_weights, - self.electron_config[atomic_numbers]) + self.atomic_bias(atomic_numbers) + def reset_parameters(self) -> None: + """Initialize parameters.""" + nn.init.zeros_(self.element_embedding.weight) + nn.init.zeros_(self.config_linear) + def forward(self, atomic_numbers): + return torch.einsum( + "fe,ne->nf", + self.config_linear, + self.electron_config[atomic_numbers], + ) + self.element_embedding(atomic_numbers) class SpookyNet(BaseNetwork): @@ -382,7 +404,7 @@ def __init__( super().__init__( dataset_statistic=dataset_statistic, postprocessing_parameter=postprocessing_parameter, - cutoff=_convert(cutoff) + cutoff=_convert(cutoff), ) from modelforge.utils.units import _convert @@ -420,7 +442,7 @@ class ElectronicEmbedding(nn.Module): num_features (int): Dimensions of feature space. num_residual (int): - TODO + TODO: """ def __init__( @@ -428,7 +450,7 @@ def __init__( num_features: int, num_residual: int, ) -> None: - """ Initializes the ElectronicEmbedding class. 
""" + """Initializes the ElectronicEmbedding class.""" super().__init__() self.linear_q = nn.Linear(num_features, num_features) # charges are duplicated to use separate weights for +/- @@ -442,7 +464,7 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters. """ + """Initialize parameters.""" nn.init.orthogonal_(self.linear_k.weight) nn.init.orthogonal_(self.linear_v.weight) nn.init.orthogonal_(self.linear_q.weight) @@ -452,7 +474,6 @@ def forward( self, x: torch.Tensor, E: torch.Tensor, - num_batch: int, eps: float = 1e-8, ) -> torch.Tensor: """ @@ -462,18 +483,15 @@ def forward( x (FloatTensor [N, num_features]): Atomic feature vectors. """ - batch_seg = torch.zeros(x.size(0), dtype=torch.int64, device=x.device) q = self.linear_q(x) # queries - e = F.relu(torch.stack([E, -E], dim=-1)) - enorm = torch.maximum(e, torch.ones_like(e)) - test = e / enorm - k = self.linear_k(test)[batch_seg] # keys - v = self.linear_v(e)[batch_seg] # values - dot = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5 # scaled dot product + e = F.relu(torch.stack([E, -E], dim=-1)) # charges are duplicated to use separate weights for +/- + enorm = torch.clamp(e, min=1) + k = self.linear_k(e / enorm) # keys + v = self.linear_v(e) # values + dot = torch.einsum("nf,nf->n", k, q) / k.shape[-1] ** 0.5 # scaled dot product a = nn.functional.softplus(dot) # unnormalized attention weights - anorm = a.new_zeros(num_batch).index_add_(0, batch_seg, a) - anorm = anorm[batch_seg] - return self.resblock((a / (anorm + eps)).unsqueeze(-1) * v) + a_normalized = a / (a.sum(-1) + eps) # TODO: why is this needed? shouldn't softplus add up to 1? + return self.resblock(torch.einsum("n,nf->nf", a_normalized, v)) class SpookyNetRepresentation(nn.Module): @@ -753,13 +771,13 @@ class SpookyNetLocalInteraction(nn.Module): number_of_radial_basis_functions (int): Number of radial basis functions. num_residual_x (int): - TODO + TODO: num_residual_s (int): - TODO + TODO: num_residual_p (int): - TODO + TODO: num_residual_d (int): - TODO + TODO: num_residual (int): Number of residual blocks to be stacked in sequence. """ @@ -1118,17 +1136,14 @@ def forward( Arguments: x (FloatTensor [N, number_of_atom_features]): Latent atomic feature vectors. - rbf (FloatTensor [P, number_of_radial_basis_functions]): - Values of the radial basis functions for the pairwise distances. + pairlist: + TODO: + filters: + TODO: dir_ij (FloatTensor [P, 3]): Unit vectors pointing from atom i to atom j for all atomic pairs. d_orbital_ij (FloatTensor [P]): Distances between atom i and atom j for all atomic pairs. - idx_i (LongTensor [P]): - Index of atom i for all atomic pairs ij. Each pair must be - specified as both ij and ji. - idx_j (LongTensor [P]): - Same as idx_i, but for atom j. Returns: x (FloatTensor [N, number_of_atom_features]): Updated latent atomic feature vectors. 
@@ -1151,5 +1166,3 @@ def forward( x_updated = self.residual_post(x_tilde + l + n) del x_tilde return x_updated, self.resblock(x_updated) - - diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index a13cf642..31f44c44 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -41,6 +41,8 @@ def test_forward(): # override default parameters config["potential"]["core_parameter"]["number_of_atom_features"] = 12 config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 7 + config["potential"]["core_parameter"]["number_of_residual_blocks"] = 1 + config["potential"]["core_parameter"]["number_of_interaction_modules"] = 1 print(f"{config['potential']['core_parameter']}=") @@ -80,6 +82,14 @@ def test_forward(): num_residual_output=config["potential"]["core_parameter"]["number_of_residual_blocks"], ).double() + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.beta + for name, param in spookynet.named_parameters(): + print(name) + + + + ref_spookynet( model_input.atomic_numbers, model_input.total_charge, From 6ad37127bf850b1069e071aac27d5691da25cf3f Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 25 Jul 2024 11:24:48 -0700 Subject: [PATCH 72/78] Copy parameters --- modelforge/tests/test_spookynet.py | 280 ++++++++++++++++++++++++++++- 1 file changed, 278 insertions(+), 2 deletions(-) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 31f44c44..a694bd65 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -87,8 +87,284 @@ def test_forward(): for name, param in spookynet.named_parameters(): print(name) - - + spookynet.core_module.atomic_embedding_module.element_embedding.weight = ref_spookynet.nuclear_embedding.element_embedding + spookynet.core_module.atomic_embedding_module.config_linear = ref_spookynet.nuclear_embedding.config_linear.weight + spookynet.core_module.charge_embedding_module.linear_q.weight = ref_spookynet.charge_embedding.linear_q.weight + spookynet.core_module.charge_embedding_module.linear_q.bias = ref_spookynet.charge_embedding.linear_q.bias + spookynet.core_module.charge_embedding_module.linear_k.weight = ref_spookynet.charge_embedding.linear_k.weight + spookynet.core_module.charge_embedding_module.linear_v.weight = ref_spookynet.charge_embedding.linear_v.weight + spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation1.alpha = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation1.alpha + spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation1.beta = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation1.beta + spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].linear1.weight = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].linear1.weight + spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation2.alpha = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation2.alpha + spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation2.beta = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation2.beta + spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].linear2.weight = \ + 
ref_spookynet.charge_embedding.resblock.residual.stack[0].linear2.weight + spookynet.core_module.charge_embedding_module.resblock.activation.alpha = ref_spookynet.charge_embedding.resblock.activation.alpha + spookynet.core_module.charge_embedding_module.resblock.activation.beta = ref_spookynet.charge_embedding.resblock.activation.beta + spookynet.core_module.charge_embedding_module.resblock.linear.weight = ref_spookynet.charge_embedding.resblock.linear.weight + spookynet.core_module.spookynet_representation_module.radial_symmetry_function_module.alpha = ref_spookynet.radial_basis_functions._alpha + spookynet.core_module.interaction_modules[0].local_interaction.radial_s.weight = ref_spookynet.module[ + 0].local_interaction.radial_s.weight + spookynet.core_module.interaction_modules[0].local_interaction.radial_p.weight = ref_spookynet.module[ + 0].local_interaction.radial_p.weight + spookynet.core_module.interaction_modules[0].local_interaction.radial_d.weight = ref_spookynet.module[ + 0].local_interaction.radial_d.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resblock_x.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.activation.beta = ref_spookynet.module[ + 0].local_interaction.resblock_x.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.linear.weight = ref_spookynet.module[ + 0].local_interaction.resblock_x.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.linear.bias = ref_spookynet.module[ + 0].local_interaction.resblock_x.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation1.alpha + 
spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resblock_s.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.activation.beta = ref_spookynet.module[ + 0].local_interaction.resblock_s.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.linear.weight = ref_spookynet.module[ + 0].local_interaction.resblock_s.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.linear.bias = ref_spookynet.module[ + 0].local_interaction.resblock_s.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear2.weight + 
spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resblock_p.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.activation.beta = ref_spookynet.module[ + 0].local_interaction.resblock_p.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.linear.weight = ref_spookynet.module[ + 0].local_interaction.resblock_p.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.linear.bias = ref_spookynet.module[ + 0].local_interaction.resblock_p.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resblock_d.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.activation.beta = ref_spookynet.module[ + 0].local_interaction.resblock_d.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.linear.weight = ref_spookynet.module[ + 0].local_interaction.resblock_d.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.linear.bias = ref_spookynet.module[ + 0].local_interaction.resblock_d.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.projection_p.weight = ref_spookynet.module[ + 0].local_interaction.projection_p.weight + spookynet.core_module.interaction_modules[0].local_interaction.projection_d.weight = ref_spookynet.module[ + 0].local_interaction.projection_d.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation1.alpha = \ + 
ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resblock.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resblock.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resblock.activation.beta = ref_spookynet.module[ + 0].local_interaction.resblock.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resblock.linear.weight = ref_spookynet.module[ + 0].local_interaction.resblock.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resblock.linear.bias = ref_spookynet.module[ + 0].local_interaction.resblock.linear.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.weight = \ + 
ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.activation.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_q.activation.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.activation.beta = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_q.activation.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.linear.weight = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_q.linear.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.linear.bias = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_q.linear.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.activation.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_k.activation.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.activation.beta = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_k.activation.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.linear.weight = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_k.linear.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.linear.bias = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_k.linear.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.alpha + 
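# ---------------------------------------------------------------------------
# NOTE: the attribute paths copied above imply the layout of each SpookyNet
# residual MLP: `residual.stack[i]` holds pre-activation residual blocks
# (activation1 -> linear1 -> activation2 -> linear2 around a skip connection),
# followed by a final `activation` and `linear`; `alpha`/`beta` are the
# learnable parameters of the Swish-style activation. A minimal sketch,
# assuming this layout (illustrative only, not the reference implementation):
import torch
import torch.nn as nn

class Swish(nn.Module):
    """Swish-style activation with learnable alpha/beta (assumed form)."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.alpha * x * torch.sigmoid(self.beta * x)

class ResidualBlock(nn.Module):
    """One entry of `residual.stack`: a pre-activation residual block."""

    def __init__(self, number_of_atom_features: int) -> None:
        super().__init__()
        f = number_of_atom_features
        self.activation1, self.linear1 = Swish(f), nn.Linear(f, f)
        self.activation2, self.linear2 = Swish(f), nn.Linear(f, f)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.linear2(self.activation2(self.linear1(self.activation1(x))))
# ---------------------------------------------------------------------------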
spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.activation.alpha = \ + ref_spookynet.module[0].nonlocal_interaction.resblock_v.activation.alpha + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.activation.beta = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_v.activation.beta + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.linear.weight = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_v.linear.weight + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.linear.bias = ref_spookynet.module[ + 0].nonlocal_interaction.resblock_v.linear.bias + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].activation1.alpha = \ + ref_spookynet.module[0].residual_pre.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].activation1.beta = \ + ref_spookynet.module[0].residual_pre.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].linear1.weight = \ + ref_spookynet.module[0].residual_pre.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].linear1.bias = \ + ref_spookynet.module[0].residual_pre.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].activation2.alpha = \ + ref_spookynet.module[0].residual_pre.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].activation2.beta = \ + ref_spookynet.module[0].residual_pre.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].linear2.weight = \ + ref_spookynet.module[0].residual_pre.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].linear2.bias = \ + ref_spookynet.module[0].residual_pre.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].residual_post.stack[0].activation1.alpha = \ + ref_spookynet.module[0].residual_post.stack[0].activation1.alpha + 
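# ---------------------------------------------------------------------------
# NOTE: on both sides of these assignments the leaf names below each submodule
# prefix are identical, so the hundreds of manual copies above could be
# collapsed into a name-matching loop. A sketch, assuming matching shapes and
# identical leaf names (renamed attributes would still need an explicit map):
import torch

def copy_matching_parameters(dst: torch.nn.Module, src: torch.nn.Module) -> None:
    """Copy every parameter of `src` into `dst`, matched by qualified name."""
    src_params = dict(src.named_parameters())
    with torch.no_grad():
        for name, dst_param in dst.named_parameters():
            dst_param.copy_(src_params[name])

# e.g. one call per submodule instead of ~20 assignments:
# copy_matching_parameters(
#     spookynet.core_module.interaction_modules[0].nonlocal_interaction,
#     ref_spookynet.module[0].nonlocal_interaction,
# )
# ---------------------------------------------------------------------------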
spookynet.core_module.interaction_modules[0].residual_post.stack[0].activation1.beta = \ + ref_spookynet.module[0].residual_post.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].residual_post.stack[0].linear1.weight = \ + ref_spookynet.module[0].residual_post.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].residual_post.stack[0].linear1.bias = \ + ref_spookynet.module[0].residual_post.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].residual_post.stack[0].activation2.alpha = \ + ref_spookynet.module[0].residual_post.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].residual_post.stack[0].activation2.beta = \ + ref_spookynet.module[0].residual_post.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].residual_post.stack[0].linear2.weight = \ + ref_spookynet.module[0].residual_post.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].residual_post.stack[0].linear2.bias = \ + ref_spookynet.module[0].residual_post.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].resblock.activation.alpha = ref_spookynet.module[ + 0].resblock.activation.alpha + spookynet.core_module.interaction_modules[0].resblock.activation.beta = ref_spookynet.module[ + 0].resblock.activation.beta + spookynet.core_module.interaction_modules[0].resblock.linear.weight = ref_spookynet.module[0].resblock.linear.weight + spookynet.core_module.interaction_modules[0].resblock.linear.bias = ref_spookynet.module[0].resblock.linear.bias ref_spookynet( model_input.atomic_numbers, From 0539ed37bef76f3ecbc166189be4017b99ecc6c8 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Thu, 25 Jul 2024 11:37:42 -0700 Subject: [PATCH 73/78] Add line breaks between different blocks when copying parameters --- modelforge/tests/test_spookynet.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index a694bd65..a47a18c7 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -80,6 +80,9 @@ def test_forward(): num_residual_nonlocal_v=config["potential"]["core_parameter"]["number_of_residual_blocks"], 
num_residual_post=config["potential"]["core_parameter"]["number_of_residual_blocks"], num_residual_output=config["potential"]["core_parameter"]["number_of_residual_blocks"], + use_zbl_repulsion=False, + use_electrostatics=False, + use_d4_dispersion=False, ).double() spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.beta = \ @@ -89,6 +92,7 @@ def test_forward(): spookynet.core_module.atomic_embedding_module.element_embedding.weight = ref_spookynet.nuclear_embedding.element_embedding spookynet.core_module.atomic_embedding_module.config_linear = ref_spookynet.nuclear_embedding.config_linear.weight + spookynet.core_module.charge_embedding_module.linear_q.weight = ref_spookynet.charge_embedding.linear_q.weight spookynet.core_module.charge_embedding_module.linear_q.bias = ref_spookynet.charge_embedding.linear_q.bias spookynet.core_module.charge_embedding_module.linear_k.weight = ref_spookynet.charge_embedding.linear_k.weight @@ -108,7 +112,9 @@ def test_forward(): spookynet.core_module.charge_embedding_module.resblock.activation.alpha = ref_spookynet.charge_embedding.resblock.activation.alpha spookynet.core_module.charge_embedding_module.resblock.activation.beta = ref_spookynet.charge_embedding.resblock.activation.beta spookynet.core_module.charge_embedding_module.resblock.linear.weight = ref_spookynet.charge_embedding.resblock.linear.weight + spookynet.core_module.spookynet_representation_module.radial_symmetry_function_module.alpha = ref_spookynet.radial_basis_functions._alpha + spookynet.core_module.interaction_modules[0].local_interaction.radial_s.weight = ref_spookynet.module[ 0].local_interaction.radial_s.weight spookynet.core_module.interaction_modules[0].local_interaction.radial_p.weight = ref_spookynet.module[ @@ -239,6 +245,7 @@ def test_forward(): 0].local_interaction.resblock.linear.weight spookynet.core_module.interaction_modules[0].local_interaction.resblock.linear.bias = ref_spookynet.module[ 0].local_interaction.resblock.linear.bias + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.beta = \ @@ -311,6 +318,7 @@ def test_forward(): 0].nonlocal_interaction.resblock_v.linear.weight spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.linear.bias = ref_spookynet.module[ 0].nonlocal_interaction.resblock_v.linear.bias + spookynet.core_module.interaction_modules[0].residual_pre.stack[0].activation1.alpha = \ ref_spookynet.module[0].residual_pre.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].residual_pre.stack[0].activation1.beta = \ @@ -327,6 +335,7 @@ def test_forward(): ref_spookynet.module[0].residual_pre.stack[0].linear2.weight spookynet.core_module.interaction_modules[0].residual_pre.stack[0].linear2.bias = \ ref_spookynet.module[0].residual_pre.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].residual_post.stack[0].activation1.alpha = \ ref_spookynet.module[0].residual_post.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].residual_post.stack[0].activation1.beta = \ @@ -343,6 +352,7 @@ def test_forward(): ref_spookynet.module[0].residual_post.stack[0].linear2.weight spookynet.core_module.interaction_modules[0].residual_post.stack[0].linear2.bias = \ 
ref_spookynet.module[0].residual_post.stack[0].linear2.bias
+
     spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation1.alpha = \
         ref_spookynet.module[0].resblock.residual.stack[0].activation1.alpha
     spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation1.beta = \

From b61358681b3c9a613d716c6ed3c944972a3e1329 Mon Sep 17 00:00:00 2001
From: Arnav Nagle
Date: Fri, 26 Jul 2024 11:25:12 -0700
Subject: [PATCH 74/78] Changes to tests and model

---
 modelforge/potential/spookynet.py           | 50 +++++++++++--------
 .../data/potential_defaults/spookynet.toml  |  8 +--
 modelforge/tests/test_spookynet.py          | 15 +++++-
 3 files changed, 47 insertions(+), 26 deletions(-)

diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py
index 9a1567b8..c428c237 100644
--- a/modelforge/potential/spookynet.py
+++ b/modelforge/potential/spookynet.py
@@ -95,12 +95,12 @@ class SpookyNetCore(CoreNetwork):
     def __init__(
         self,
-        max_Z: int = 87,  # need to update electron_config if we want to use higher atomic numbers
-        number_of_atom_features: int = 64,
-        number_of_radial_basis_functions: int = 20,
-        number_of_interaction_modules: int = 3,
-        number_of_residual_blocks: int = 7,
-        cutoff: unit.Quantity = 5.0 * unit.angstrom,
+        max_Z,
+        cutoff: unit.Quantity,
+        number_of_atom_features,
+        number_of_radial_basis_functions,
+        number_of_interaction_modules,
+        number_of_residual_blocks,
     ) -> None:
         """
         Initialize the SpookyNet class.
@@ -123,9 +123,6 @@ def __init__(
         self.number_of_atom_features = number_of_atom_features
         self.number_of_radial_basis_functions = number_of_radial_basis_functions

-        # embedding
-        from modelforge.potential.utils import Embedding
-
         assert max_Z <= 87
         self.atomic_embedding_module = SpookyNetAtomicEmbedding(
             number_of_atom_features, max_Z
@@ -139,6 +136,7 @@ def __init__(
             cutoff, number_of_radial_basis_functions
         )

+        ic(number_of_interaction_modules)
         # Initialize interaction blocks
         self.interaction_modules = nn.ModuleList(
             [
@@ -162,18 +160,18 @@ def __init__(
         )

         # final output layer
-        self.energy_layer = nn.Sequential(
-            Dense(
-                number_of_atom_features,
-                number_of_atom_features,
-                activation=ShiftedSoftplus(),
-            ),
+        self.energy_and_charge_readout = nn.Sequential(
             Dense(
                 number_of_atom_features,
-                1,
+                2,
+                activation=None,
+                bias=False,
             ),
         )

+        # learnable per-element shift that is added to each atomic energy and partial charge
+        self.atomic_shift = nn.Parameter(torch.zeros(max_Z, 2))
+
     def _model_specific_input_preparation(
         self, data: "NNPInput", pairlist_output: "PairListOutputs"
     ) -> SpookyNetNeuralNetworkData:
@@ -232,14 +230,23 @@ def compute_properties(
             )
             f += y  # accumulate module output to features

-        E_i = self.energy_layer(x).squeeze(1)
+        per_atom_energy_and_charge = self.energy_and_charge_readout(x)

-        return {
-            "E_i": E_i,
-            "q": x,
+        per_atom_energy_and_charge_shifted = self.atomic_shift[data.atomic_numbers] + per_atom_energy_and_charge
+
+        E_i = per_atom_energy_and_charge_shifted[:, 0]  # shape (nr_of_atoms,)
+        q_i = per_atom_energy_and_charge_shifted[:, 1]  # shape (nr_of_atoms,)
+
+        output = {
+            "per_atom_energy": E_i.contiguous(),  # reshape memory mapping for JAX/dlpack
+            "q_i": q_i.contiguous(),
             "atomic_subsystem_indices": data.atomic_subsystem_indices,
+            "atomic_numbers": data.atomic_numbers,
         }

+        return output
+
+
 from .models import InputPreparation, NNPInput, BaseNetwork
@@ -410,6 +417,7 @@ def __init__(
         self.core_module = SpookyNetCore(
             max_Z=max_Z,
+
cutoff=_convert(cutoff), number_of_atom_features=number_of_atom_features, number_of_radial_basis_functions=number_of_radial_basis_functions, number_of_interaction_modules=number_of_interaction_modules, @@ -879,6 +887,8 @@ def forward( pa, pb = torch.split(self.projection_p(p), p.shape[-1], dim=-1) da, db = torch.split(self.projection_d(d), d.shape[-1], dim=-1) # n: number_of_atoms_in_system, x: 3 (geometry axis), f: number_of_atom_features + ic(pa.shape) + ic(f_ij_after_cutoff.shape) return self.resblock( s + torch.einsum("nxf,nxf->nf", pa, pb) diff --git a/modelforge/tests/data/potential_defaults/spookynet.toml b/modelforge/tests/data/potential_defaults/spookynet.toml index 0c7ba898..a4e64269 100644 --- a/modelforge/tests/data/potential_defaults/spookynet.toml +++ b/modelforge/tests/data/potential_defaults/spookynet.toml @@ -3,11 +3,11 @@ model_name = "SpookyNet" [potential.core_parameter] max_Z = 87 -number_of_atom_features = 32 -number_of_radial_basis_functions = 20 -cutoff = "5.0 angstrom" +number_of_atom_features = 64 +number_of_radial_basis_functions = 16 +cutoff = "5.291772105638412 angstrom" # 10 a0 number_of_interaction_modules = 3 -number_of_residual_blocks = 7 +number_of_residual_blocks = 1 [potential.postprocessing_parameter] [potential.postprocessing_parameter.per_atom_energy] diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index a47a18c7..d98fea6f 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -7,6 +7,8 @@ from icecream import ic import pytest +from modelforge.utils.units import _convert + def test_init(): """Test initialization of the SpookyNet model.""" @@ -43,6 +45,7 @@ def test_forward(): config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 7 config["potential"]["core_parameter"]["number_of_residual_blocks"] = 1 config["potential"]["core_parameter"]["number_of_interaction_modules"] = 1 + config["potential"]["core_parameter"]["cutoff"] = "2.4 meter" print(f"{config['potential']['core_parameter']}=") @@ -83,6 +86,7 @@ def test_forward(): use_zbl_repulsion=False, use_electrostatics=False, use_d4_dispersion=False, + cutoff=_convert(config["potential"]["core_parameter"]["cutoff"]).m_as(unit.angstrom), ).double() spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.beta = \ @@ -90,6 +94,10 @@ def test_forward(): for name, param in spookynet.named_parameters(): print(name) + for param in spookynet.parameters(): + torch.nn.init.normal_(param, 5.0, 3.0) + + spookynet.core_module.atomic_shift = ref_spookynet.element_bias spookynet.core_module.atomic_embedding_module.element_embedding.weight = ref_spookynet.nuclear_embedding.element_embedding spookynet.core_module.atomic_embedding_module.config_linear = ref_spookynet.nuclear_embedding.config_linear.weight @@ -376,14 +384,17 @@ def test_forward(): spookynet.core_module.interaction_modules[0].resblock.linear.weight = ref_spookynet.module[0].resblock.linear.weight spookynet.core_module.interaction_modules[0].resblock.linear.bias = ref_spookynet.module[0].resblock.linear.bias - ref_spookynet( + reference_calculated_results = ref_spookynet( model_input.atomic_numbers, model_input.total_charge, - model_input.positions, + (model_input.positions * unit.nanometer).m_as(unit.angstrom), pairlist_output.pair_indices[0], pairlist_output.pair_indices[1], ) + ic(calculated_results.keys()) + ic(reference_calculated_results) + def test_spookynet_forward(single_batch_with_batchsize_64, 
model_parameter): """ From 10c9914807dcbed7a38db1578ebc57fa6c725470 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 30 Jul 2024 10:33:45 -0700 Subject: [PATCH 75/78] More changes --- modelforge/potential/spookynet.py | 29 ++++++++++++++++++++++------- modelforge/tests/test_spookynet.py | 29 +++++++++++++++++++++++------ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index c428c237..617d4a88 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -180,7 +180,7 @@ def _model_specific_input_preparation( atomic_embedding = self.atomic_embedding_module(data.atomic_numbers) charge_embedding = self.charge_embedding_module( - atomic_embedding, data.total_charge + atomic_embedding, data.total_charge, data.atomic_subsystem_indices ) nnp_input = SpookyNetNeuralNetworkData( @@ -482,6 +482,7 @@ def forward( self, x: torch.Tensor, E: torch.Tensor, + atomic_subsystem_indices: torch.Tensor, eps: float = 1e-8, ) -> torch.Tensor: """ @@ -491,12 +492,16 @@ def forward( x (FloatTensor [N, num_features]): Atomic feature vectors. """ + ic(E.shape) q = self.linear_q(x) # queries + ic(q.shape) e = F.relu(torch.stack([E, -E], dim=-1)) # charges are duplicated to use separate weights for +/- - enorm = torch.clamp(e, min=1) - k = self.linear_k(e / enorm) # keys - v = self.linear_v(e) # values - dot = torch.einsum("nf,nf->n", k, q) / k.shape[-1] ** 0.5 # scaled dot product + ic(e.shape) + k = self.linear_k(e / torch.clamp(e, min=1))[atomic_subsystem_indices] # keys + ic(k.shape) + v = self.linear_v(e)[atomic_subsystem_indices] # values + ic(v.shape) + dot = torch.einsum("nf,nf->n", k, q) / math.sqrt(k.shape[-1]) # scaled dot product a = nn.functional.softplus(dot) # unnormalized attention weights a_normalized = a / (a.sum(-1) + eps) # TODO: why is this needed? shouldn't softplus add up to 1? 
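        # NOTE on the TODO above: softplus is not softmax, so the per-atom
        # weights do not sum to one on their own; the explicit division is what
        # normalizes them. As written, `a.sum(-1)` sums over *all* atoms in the
        # batch, mixing molecules. A per-molecule normalization would use a
        # segment sum over `atomic_subsystem_indices` (sketch, assuming
        # `num_systems` molecules in the batch):
        #     denom = torch.zeros(num_systems, dtype=a.dtype).index_add_(
        #         0, atomic_subsystem_indices, a
        #     )
        #     a_normalized = a / (denom[atomic_subsystem_indices] + eps)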
return self.resblock(torch.einsum("n,nf->nf", a_normalized, v)) @@ -985,8 +990,18 @@ def forward( """ Q = self._phi(Q, True) # random projection of Q K = self._phi(K, False) # random projection of K - norm = Q @ torch.sum(K, 0, keepdim=True).T + eps - return (Q @ (K.T @ V)) / norm + ic(Q.shape) + ic(K.shape) + ic(V.shape) + ic(torch.sum(K, 0, keepdim=True).T) + norm = torch.einsum("nf,f->n", Q, torch.sum(K, dim=0)) + eps + ic(norm.shape) + # n: number of atoms, F: dim_qk, f: value features + rv = torch.einsum("nF,nF,nf->nf", Q, K, V) / norm.unsqueeze(-1) + ic((K.T @ V).shape) + ic((Q @ (K.T @ V)).shape) + ic(rv.shape) + return rv class SpookyNetNonlocalInteraction(nn.Module): diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index d98fea6f..4160f73d 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -1,3 +1,4 @@ +from modelforge.dataset.dataset import NNPInput from modelforge.potential.spookynet import SpookyNet from spookynet import SpookyNet as RefSpookyNet from modelforge.tests.precalculated_values import ( @@ -41,7 +42,7 @@ def test_forward(): config = load_configs(f"spookynet", "qm9") # override default parameters - config["potential"]["core_parameter"]["number_of_atom_features"] = 12 + config["potential"]["core_parameter"]["number_of_atom_features"] = 11 config["potential"]["core_parameter"]["number_of_radial_basis_functions"] = 7 config["potential"]["core_parameter"]["number_of_residual_blocks"] = 1 config["potential"]["core_parameter"]["number_of_interaction_modules"] = 1 @@ -57,14 +58,25 @@ def test_forward(): postprocessing_parameter=config["potential"]["postprocessing_parameter"], ).double() - input = setup_single_methane_input() - model_input = input["modelforge_methane_input"] + single_methane = setup_single_methane_input()["modelforge_methane_input"] + model_input = NNPInput( + atomic_numbers=torch.cat([single_methane.atomic_numbers] * 2, dim=0), + positions=torch.cat([single_methane.positions] * 2, dim=0), + atomic_subsystem_indices=torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), + total_charge=torch.cat([single_methane.total_charge] * 2, dim=0), + ) + print(f"{torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int64).dtype=}") + print(f"test: {model_input.atomic_subsystem_indices.dtype=}") + ic(model_input) model_input.positions = model_input.positions.double() model_input.total_charge = model_input.total_charge.double() + print(f"test: {model_input.atomic_subsystem_indices.dtype=}") spookynet.input_preparation._input_checks(model_input) + print(f"test: {model_input.atomic_subsystem_indices.dtype=}") pairlist_output = spookynet.input_preparation.prepare_inputs(model_input) + print(f"test: {model_input.atomic_subsystem_indices.dtype=}") calculated_results = spookynet.core_module.forward(model_input, pairlist_output) ref_spookynet = RefSpookyNet( @@ -91,10 +103,8 @@ def test_forward(): spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.beta = \ ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.beta - for name, param in spookynet.named_parameters(): - print(name) - for param in spookynet.parameters(): + for param in ref_spookynet.parameters(): torch.nn.init.normal_(param, 5.0, 3.0) spookynet.core_module.atomic_shift = ref_spookynet.element_bias @@ -383,13 +393,20 @@ def test_forward(): 0].resblock.activation.beta spookynet.core_module.interaction_modules[0].resblock.linear.weight = 
ref_spookynet.module[0].resblock.linear.weight spookynet.core_module.interaction_modules[0].resblock.linear.bias = ref_spookynet.module[0].resblock.linear.bias + spookynet.core_module.energy_and_charge_readout.weight = ref_spookynet.output.weight + + ref_spookynet.train() + # TODO: how are multiple systems passed into the reference SpookyNet + print(f"test: {model_input.atomic_subsystem_indices.dtype=}") reference_calculated_results = ref_spookynet( model_input.atomic_numbers, model_input.total_charge, (model_input.positions * unit.nanometer).m_as(unit.angstrom), pairlist_output.pair_indices[0], pairlist_output.pair_indices[1], + batch_seg=model_input.atomic_subsystem_indices.long(), + num_batch=2, ) ic(calculated_results.keys()) From d170b79c93ba1829527eed58b0f323961c1446a3 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 13 Aug 2024 08:04:15 -0700 Subject: [PATCH 76/78] Update docstrings --- modelforge/potential/spookynet.py | 82 ++++---- modelforge/tests/test_spookynet.py | 322 ++++++++++++++--------------- 2 files changed, 202 insertions(+), 202 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index 617d4a88..cf7dc6b8 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -223,7 +223,7 @@ def compute_properties( for interaction in self.interaction_modules: x, y = interaction( x=x, - pairlist=data.pair_indices, + pair_indices=data.pair_indices, filters=representation["filters"], dir_ij=representation["dir_ij"], d_orbital_ij=representation["d_orbital_ij"], @@ -776,34 +776,34 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class SpookyNetLocalInteraction(nn.Module): """ Block for updating atomic features through local interactions with - neighboring atoms (message-passing). + neighboring atoms (message-passing) described in Eq. 12. Arguments: number_of_atom_features (int): Dimensions of feature space. number_of_radial_basis_functions (int): Number of radial basis functions. - num_residual_x (int): - TODO: - num_residual_s (int): - TODO: - num_residual_p (int): - TODO: - num_residual_d (int): - TODO: - num_residual (int): - Number of residual blocks to be stacked in sequence. + num_residual_local_x (int): + Number of residual blocks applied to atomic features in resmlp_c in Eq. 12. + num_residual_local_s (int): + Number of residual blocks applied to atomic features in resmlp_s in Eq. 12. + num_residual_local_p (int): + Number of residual blocks applied to atomic features in resmlp_p in Eq. 12. + num_residual_local_d (int): + Number of residual blocks applied to atomic features in resmlp_d in Eq. 12. + num_residual_local (int): + Number of residual blocks applied to atomic features in resmlp_l in Eq. 12. 
""" def __init__( self, number_of_atom_features: int, number_of_radial_basis_functions: int, - num_residual_x: int, - num_residual_s: int, - num_residual_p: int, - num_residual_d: int, - num_residual: int, + num_residual_local_x: int, + num_residual_local_s: int, + num_residual_local_p: int, + num_residual_local_d: int, + num_residual_local: int, ) -> None: """Initializes the LocalInteraction class.""" super(SpookyNetLocalInteraction, self).__init__() @@ -816,17 +816,17 @@ def __init__( self.radial_d = nn.Linear( number_of_radial_basis_functions, number_of_atom_features, bias=False ) - self.resblock_x = SpookyNetResidualMLP(number_of_atom_features, num_residual_x) - self.resblock_s = SpookyNetResidualMLP(number_of_atom_features, num_residual_s) - self.resblock_p = SpookyNetResidualMLP(number_of_atom_features, num_residual_p) - self.resblock_d = SpookyNetResidualMLP(number_of_atom_features, num_residual_d) + self.resmlp_x = SpookyNetResidualMLP(number_of_atom_features, num_residual_local_x) + self.resmlp_s = SpookyNetResidualMLP(number_of_atom_features, num_residual_local_s) + self.resmlp_p = SpookyNetResidualMLP(number_of_atom_features, num_residual_local_p) + self.resmlp_d = SpookyNetResidualMLP(number_of_atom_features, num_residual_local_d) self.projection_p = nn.Linear( number_of_atom_features, 2 * number_of_atom_features, bias=False ) self.projection_d = nn.Linear( number_of_atom_features, 2 * number_of_atom_features, bias=False ) - self.resblock = SpookyNetResidualMLP(number_of_atom_features, num_residual) + self.resmlp_l = SpookyNetResidualMLP(number_of_atom_features, num_residual_local) self.reset_parameters() def reset_parameters(self) -> None: @@ -871,10 +871,10 @@ def forward( gp = torch.einsum("pf,pr->prf", self.radial_p(f_ij_after_cutoff), dir_ij) gd = torch.einsum("pf,pr->prf", self.radial_d(f_ij_after_cutoff), d_orbital_ij) # atom featurizations - xx = self.resblock_x(x_tilde) - xs = self.resblock_s(x_tilde) - xp = self.resblock_p(x_tilde) - xd = self.resblock_d(x_tilde) + xx = self.resmlp_x(x_tilde) + xs = self.resmlp_s(x_tilde) + xp = self.resmlp_p(x_tilde) + xd = self.resmlp_d(x_tilde) # collect neighbors xs = xs[idx_j] # L=0 xp = xp[idx_j] # L=1 @@ -894,7 +894,7 @@ def forward( # n: number_of_atoms_in_system, x: 3 (geometry axis), f: number_of_atom_features ic(pa.shape) ic(f_ij_after_cutoff.shape) - return self.resblock( + return self.resmlp_l( s + torch.einsum("nxf,nxf->nf", pa, pb) + torch.einsum("nxf,nxf->nf", da, db) @@ -1071,15 +1071,15 @@ class SpookyNetInteractionModule(nn.Module): Number of residual blocks applied to atomic features before interaction with neighbouring atoms. num_residual_local_x (int): - TODO + Number of residual blocks applied to atomic features in resmlp_c in Eq. 12. num_residual_local_s (int): - TODO + Number of residual blocks applied to atomic features in resmlp_s in Eq. 12. num_residual_local_p (int): - TODO + Number of residual blocks applied to atomic features in resmlp_p in Eq. 12. num_residual_local_d (int): - TODO + Number of residual blocks applied to atomic features in resmlp_d in Eq. 12. num_residual_local (int): - TODO + Number of residual blocks applied to atomic features in resmlp_l in Eq. 12. num_residual_nonlocal_q (int): Number of residual blocks for queries in nonlocal interactions. 
num_residual_nonlocal_k (int): @@ -1116,11 +1116,11 @@ def __init__( self.local_interaction = SpookyNetLocalInteraction( number_of_atom_features=number_of_atom_features, number_of_radial_basis_functions=number_of_radial_basis_functions, - num_residual_x=num_residual_local_x, - num_residual_s=num_residual_local_s, - num_residual_p=num_residual_local_p, - num_residual_d=num_residual_local_d, - num_residual=num_residual_local, + num_residual_local_x=num_residual_local_x, + num_residual_local_s=num_residual_local_s, + num_residual_local_p=num_residual_local_p, + num_residual_local_d=num_residual_local_d, + num_residual_local=num_residual_local, ) self.nonlocal_interaction = SpookyNetNonlocalInteraction( number_of_atom_features=number_of_atom_features, @@ -1147,7 +1147,7 @@ def reset_parameters(self) -> None: def forward( self, x: torch.Tensor, - pairlist: torch.Tensor, # shape [n_pairs, 2] + pair_indices: torch.Tensor, # shape [n_pairs, 2] filters: torch.Tensor, # shape [n_pairs, 1, number_of_radial_basis_functions] TODO: why the 1? dir_ij: torch.Tensor, # shape [n_pairs, 1] d_orbital_ij: torch.Tensor, # shape [n_pairs, 1] @@ -1161,8 +1161,8 @@ def forward( Arguments: x (FloatTensor [N, number_of_atom_features]): Latent atomic feature vectors. - pairlist: - TODO: + pair_indices : + Indices of atom pairs within the maximum interaction radius. filters: TODO: dir_ij (FloatTensor [P, 3]): @@ -1176,7 +1176,7 @@ def forward( Contribution to output atomic features (environment descriptors). """ - idx_i, idx_j = pairlist[0], pairlist[1] + idx_i, idx_j = pair_indices[0], pair_indices[1] x_tilde = self.residual_pre(x) del x l = self.local_interaction( diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index 4160f73d..e21dea18 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -101,8 +101,8 @@ def test_forward(): cutoff=_convert(config["potential"]["core_parameter"]["cutoff"]).m_as(unit.angstrom), ).double() - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta for param in ref_spookynet.parameters(): torch.nn.init.normal_(param, 5.0, 3.0) @@ -115,21 +115,21 @@ def test_forward(): spookynet.core_module.charge_embedding_module.linear_q.bias = ref_spookynet.charge_embedding.linear_q.bias spookynet.core_module.charge_embedding_module.linear_k.weight = ref_spookynet.charge_embedding.linear_k.weight spookynet.core_module.charge_embedding_module.linear_v.weight = ref_spookynet.charge_embedding.linear_v.weight - spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation1.alpha = \ - ref_spookynet.charge_embedding.resblock.residual.stack[0].activation1.alpha - spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation1.beta = \ - ref_spookynet.charge_embedding.resblock.residual.stack[0].activation1.beta - spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].linear1.weight = \ - ref_spookynet.charge_embedding.resblock.residual.stack[0].linear1.weight - spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation2.alpha = \ - ref_spookynet.charge_embedding.resblock.residual.stack[0].activation2.alpha - 
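# ---------------------------------------------------------------------------
# NOTE: Eq. 12 of the SpookyNet paper combines a central-atom feature with
# rotationally invariant contractions of the s/p/d neighbor features (the
# paper's resmlp_c corresponds to resmlp_x here; that mapping is an assumption
# based on the docstrings above). A sketch of the aggregation returned by
# SpookyNetLocalInteraction.forward, following the einsum excerpt shown
# earlier in this patch series (shapes are indicative only):
import torch

def local_aggregation(resmlp_l, s, pa, pb, da, db):
    """s: [n_atoms, n_features]; pa/pb, da/db: [n_atoms, n_geom, n_features]."""
    # contracting over the geometric axis keeps the result rotationally invariant
    p_term = torch.einsum("nxf,nxf->nf", pa, pb)
    d_term = torch.einsum("nxf,nxf->nf", da, db)
    return resmlp_l(s + p_term + d_term)
# ---------------------------------------------------------------------------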
spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].activation2.beta = \ - ref_spookynet.charge_embedding.resblock.residual.stack[0].activation2.beta - spookynet.core_module.charge_embedding_module.resblock.residual.stack[0].linear2.weight = \ - ref_spookynet.charge_embedding.resblock.residual.stack[0].linear2.weight - spookynet.core_module.charge_embedding_module.resblock.activation.alpha = ref_spookynet.charge_embedding.resblock.activation.alpha - spookynet.core_module.charge_embedding_module.resblock.activation.beta = ref_spookynet.charge_embedding.resblock.activation.beta - spookynet.core_module.charge_embedding_module.resblock.linear.weight = ref_spookynet.charge_embedding.resblock.linear.weight + spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation1.alpha = \ + ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation1.alpha + spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation1.beta = \ + ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation1.beta + spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].linear1.weight = \ + ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].linear1.weight + spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation2.alpha = \ + ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation2.alpha + spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation2.beta = \ + ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation2.beta + spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].linear2.weight = \ + ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].linear2.weight + spookynet.core_module.charge_embedding_module.resmlp_l.activation.alpha = ref_spookynet.charge_embedding.resmlp_l.activation.alpha + spookynet.core_module.charge_embedding_module.resmlp_l.activation.beta = ref_spookynet.charge_embedding.resmlp_l.activation.beta + spookynet.core_module.charge_embedding_module.resmlp_l.linear.weight = ref_spookynet.charge_embedding.resmlp_l.linear.weight spookynet.core_module.spookynet_representation_module.radial_symmetry_function_module.alpha = ref_spookynet.radial_basis_functions._alpha @@ -139,130 +139,130 @@ def test_forward(): 0].local_interaction.radial_p.weight spookynet.core_module.interaction_modules[0].local_interaction.radial_d.weight = ref_spookynet.module[ 0].local_interaction.radial_d.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation2.alpha - 
spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resblock_x.activation.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.activation.beta = ref_spookynet.module[ - 0].local_interaction.resblock_x.activation.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.linear.weight = ref_spookynet.module[ - 0].local_interaction.resblock_x.linear.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_x.linear.bias = ref_spookynet.module[ - 0].local_interaction.resblock_x.linear.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resblock_s.activation.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.activation.beta = ref_spookynet.module[ - 0].local_interaction.resblock_s.activation.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.linear.weight = ref_spookynet.module[ - 0].local_interaction.resblock_s.linear.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_s.linear.bias = ref_spookynet.module[ - 0].local_interaction.resblock_s.linear.bias - 
spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resblock_p.activation.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.activation.beta = ref_spookynet.module[ - 0].local_interaction.resblock_p.activation.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.linear.weight = ref_spookynet.module[ - 0].local_interaction.resblock_p.linear.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_p.linear.bias = ref_spookynet.module[ - 0].local_interaction.resblock_p.linear.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.beta - 
spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resblock_d.activation.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.activation.beta = ref_spookynet.module[ - 0].local_interaction.resblock_d.activation.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.linear.weight = ref_spookynet.module[ - 0].local_interaction.resblock_d.linear.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock_d.linear.bias = ref_spookynet.module[ - 0].local_interaction.resblock_d.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resmlp_x.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.activation.beta = ref_spookynet.module[ + 0].local_interaction.resmlp_x.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.linear.weight = ref_spookynet.module[ + 0].local_interaction.resmlp_x.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.linear.bias = ref_spookynet.module[ + 0].local_interaction.resmlp_x.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation1.alpha + 
spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resmlp_s.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.activation.beta = ref_spookynet.module[ + 0].local_interaction.resmlp_s.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.linear.weight = ref_spookynet.module[ + 0].local_interaction.resmlp_s.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.linear.bias = ref_spookynet.module[ + 0].local_interaction.resmlp_s.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear2.bias = \ + 
ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resmlp_p.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.activation.beta = ref_spookynet.module[ + 0].local_interaction.resmlp_p.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.linear.weight = ref_spookynet.module[ + 0].local_interaction.resmlp_p.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.linear.bias = ref_spookynet.module[ + 0].local_interaction.resmlp_p.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resmlp_d.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.activation.beta = ref_spookynet.module[ + 0].local_interaction.resmlp_d.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.linear.weight = ref_spookynet.module[ + 0].local_interaction.resmlp_d.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.linear.bias = ref_spookynet.module[ + 0].local_interaction.resmlp_d.linear.bias spookynet.core_module.interaction_modules[0].local_interaction.projection_p.weight = ref_spookynet.module[ 0].local_interaction.projection_p.weight spookynet.core_module.interaction_modules[0].local_interaction.projection_d.weight = ref_spookynet.module[ 0].local_interaction.projection_d.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation1.beta 
= \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].local_interaction.resblock.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resblock.activation.alpha - spookynet.core_module.interaction_modules[0].local_interaction.resblock.activation.beta = ref_spookynet.module[ - 0].local_interaction.resblock.activation.beta - spookynet.core_module.interaction_modules[0].local_interaction.resblock.linear.weight = ref_spookynet.module[ - 0].local_interaction.resblock.linear.weight - spookynet.core_module.interaction_modules[0].local_interaction.resblock.linear.bias = ref_spookynet.module[ - 0].local_interaction.resblock.linear.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear2.bias + 
spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.activation.alpha = ref_spookynet.module[ + 0].local_interaction.resmlp_l.activation.alpha + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.activation.beta = ref_spookynet.module[ + 0].local_interaction.resmlp_l.activation.beta + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.linear.weight = ref_spookynet.module[ + 0].local_interaction.resmlp_l.linear.weight + spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.linear.bias = ref_spookynet.module[ + 0].local_interaction.resmlp_l.linear.bias spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha @@ -371,28 +371,28 @@ def test_forward(): spookynet.core_module.interaction_modules[0].residual_post.stack[0].linear2.bias = \ ref_spookynet.module[0].residual_post.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].resblock.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].resblock.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].resblock.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].resblock.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].resblock.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].resblock.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].resblock.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].resblock.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].resblock.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].resblock.activation.alpha = ref_spookynet.module[ - 0].resblock.activation.alpha - spookynet.core_module.interaction_modules[0].resblock.activation.beta = ref_spookynet.module[ - 0].resblock.activation.beta - spookynet.core_module.interaction_modules[0].resblock.linear.weight = ref_spookynet.module[0].resblock.linear.weight - spookynet.core_module.interaction_modules[0].resblock.linear.bias = ref_spookynet.module[0].resblock.linear.bias + spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].linear1.bias + 
spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].resmlp_l.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].resmlp_l.activation.alpha = ref_spookynet.module[ + 0].resmlp_l.activation.alpha + spookynet.core_module.interaction_modules[0].resmlp_l.activation.beta = ref_spookynet.module[ + 0].resmlp_l.activation.beta + spookynet.core_module.interaction_modules[0].resmlp_l.linear.weight = ref_spookynet.module[0].resmlp_l.linear.weight + spookynet.core_module.interaction_modules[0].resmlp_l.linear.bias = ref_spookynet.module[0].resmlp_l.linear.bias spookynet.core_module.energy_and_charge_readout.weight = ref_spookynet.output.weight ref_spookynet.train() @@ -557,9 +557,9 @@ def test_spookynet_interaction_module_against_reference(): mf_param[:] = ref_param assert len(list(mf_spookynet_interaction_module.resblock.named_parameters())) == len( - list(ref_spookynet_interaction_module.resblock.named_parameters())) + list(ref_spookynet_interaction_module.resmlp_l.named_parameters())) for (mf_name, mf_param), (ref_name, ref_param) in zip(mf_spookynet_interaction_module.resblock.named_parameters(), - ref_spookynet_interaction_module.resblock.named_parameters()): + ref_spookynet_interaction_module.resmlp_l.named_parameters()): print(f"{mf_name=} {ref_name=}") if not torch.equal(mf_param, ref_param): print(f"{mf_param=} {ref_param=}") From 5f9666d561ba901fe6ddd5b71c5fc67b65a5e1ec Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Tue, 13 Aug 2024 10:05:03 -0700 Subject: [PATCH 77/78] Fix naming resblock -> resmlp --- modelforge/potential/spookynet.py | 20 +-- modelforge/tests/test_spookynet.py | 273 ++++++++++++++--------------- 2 files changed, 145 insertions(+), 148 deletions(-) diff --git a/modelforge/potential/spookynet.py b/modelforge/potential/spookynet.py index cf7dc6b8..6e7a2764 100644 --- a/modelforge/potential/spookynet.py +++ b/modelforge/potential/spookynet.py @@ -464,7 +464,7 @@ def __init__( # charges are duplicated to use separate weights for +/- self.linear_k = nn.Linear(2, num_features, bias=False) self.linear_v = nn.Linear(2, num_features, bias=False) - self.resblock = SpookyNetResidualMLP( + self.resmlp = SpookyNetResidualMLP( num_features, num_residual, bias=False, @@ -504,7 +504,7 @@ def forward( dot = torch.einsum("nf,nf->n", k, q) / math.sqrt(k.shape[-1]) # scaled dot product a = nn.functional.softplus(dot) # unnormalized attention weights a_normalized = a / (a.sum(-1) + eps) # TODO: why is this needed? shouldn't softplus add up to 1? 
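        # Note: unlike softmax, softplus maps each logit independently to a
        # positive value, so the weights do not sum to one on their own; the
        # division by a.sum(-1) performs that normalization, and eps merely
        # guards the all-zero edge case (softplus(0) == ln 2, so e.g. three
        # zero logits would otherwise give weights summing to 3*ln 2, not 1).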
- return self.resblock(torch.einsum("n,nf->nf", a_normalized, v)) + return self.resmlp(torch.einsum("n,nf->nf", a_normalized, v)) class SpookyNetRepresentation(nn.Module): @@ -1029,9 +1029,9 @@ def __init__( ) -> None: """Initializes the NonlocalInteraction class.""" super(SpookyNetNonlocalInteraction, self).__init__() - self.resblock_q = SpookyNetResidualMLP(number_of_atom_features, num_residual_q) - self.resblock_k = SpookyNetResidualMLP(number_of_atom_features, num_residual_k) - self.resblock_v = SpookyNetResidualMLP(number_of_atom_features, num_residual_v) + self.resmlp_q = SpookyNetResidualMLP(number_of_atom_features, num_residual_q) + self.resmlp_k = SpookyNetResidualMLP(number_of_atom_features, num_residual_k) + self.resmlp_v = SpookyNetResidualMLP(number_of_atom_features, num_residual_v) self.attention = SpookyNetAttention( dim_qk=number_of_atom_features, num_random_features=number_of_atom_features ) @@ -1052,9 +1052,9 @@ def forward( x (FloatTensor [N, number_of_atom_features]): Atomic feature vectors. """ - q = self.resblock_q(x_tilde) # queries - k = self.resblock_k(x_tilde) # keys - v = self.resblock_v(x_tilde) # values + q = self.resmlp_q(x_tilde) # queries + k = self.resmlp_k(x_tilde) # keys + v = self.resmlp_v(x_tilde) # values return self.attention(q, k, v) @@ -1135,7 +1135,7 @@ def __init__( self.residual_post = SpookyNetResidualStack( number_of_atom_features, num_residual_post ) - self.resblock = SpookyNetResidualMLP( + self.resmlp = SpookyNetResidualMLP( number_of_atom_features, num_residual_output ) self.reset_parameters() @@ -1190,4 +1190,4 @@ def forward( n = self.nonlocal_interaction(x_tilde) x_updated = self.residual_post(x_tilde + l + n) del x_tilde - return x_updated, self.resblock(x_updated) + return x_updated, self.resmlp(x_updated) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index e21dea18..c905dfc6 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -101,9 +101,6 @@ def test_forward(): cutoff=_convert(config["potential"]["core_parameter"]["cutoff"]).m_as(unit.angstrom), ).double() - spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta - for param in ref_spookynet.parameters(): torch.nn.init.normal_(param, 5.0, 3.0) @@ -115,21 +112,21 @@ def test_forward(): spookynet.core_module.charge_embedding_module.linear_q.bias = ref_spookynet.charge_embedding.linear_q.bias spookynet.core_module.charge_embedding_module.linear_k.weight = ref_spookynet.charge_embedding.linear_k.weight spookynet.core_module.charge_embedding_module.linear_v.weight = ref_spookynet.charge_embedding.linear_v.weight - spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation1.alpha = \ - ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation1.alpha - spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation1.beta = \ - ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation1.beta - spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].linear1.weight = \ - ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].linear1.weight - spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation2.alpha = \ - ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation2.alpha - 
spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].activation2.beta = \ - ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].activation2.beta - spookynet.core_module.charge_embedding_module.resmlp_l.residual.stack[0].linear2.weight = \ - ref_spookynet.charge_embedding.resmlp_l.residual.stack[0].linear2.weight - spookynet.core_module.charge_embedding_module.resmlp_l.activation.alpha = ref_spookynet.charge_embedding.resmlp_l.activation.alpha - spookynet.core_module.charge_embedding_module.resmlp_l.activation.beta = ref_spookynet.charge_embedding.resmlp_l.activation.beta - spookynet.core_module.charge_embedding_module.resmlp_l.linear.weight = ref_spookynet.charge_embedding.resmlp_l.linear.weight + spookynet.core_module.charge_embedding_module.resmlp.residual.stack[0].activation1.alpha = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation1.alpha + spookynet.core_module.charge_embedding_module.resmlp.residual.stack[0].activation1.beta = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation1.beta + spookynet.core_module.charge_embedding_module.resmlp.residual.stack[0].linear1.weight = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].linear1.weight + spookynet.core_module.charge_embedding_module.resmlp.residual.stack[0].activation2.alpha = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation2.alpha + spookynet.core_module.charge_embedding_module.resmlp.residual.stack[0].activation2.beta = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].activation2.beta + spookynet.core_module.charge_embedding_module.resmlp.residual.stack[0].linear2.weight = \ + ref_spookynet.charge_embedding.resblock.residual.stack[0].linear2.weight + spookynet.core_module.charge_embedding_module.resmlp.activation.alpha = ref_spookynet.charge_embedding.resblock.activation.alpha + spookynet.core_module.charge_embedding_module.resmlp.activation.beta = ref_spookynet.charge_embedding.resblock.activation.beta + spookynet.core_module.charge_embedding_module.resmlp.linear.weight = ref_spookynet.charge_embedding.resblock.linear.weight spookynet.core_module.spookynet_representation_module.radial_symmetry_function_module.alpha = ref_spookynet.radial_basis_functions._alpha @@ -140,201 +137,201 @@ def test_forward(): spookynet.core_module.interaction_modules[0].local_interaction.radial_d.weight = ref_spookynet.module[ 0].local_interaction.radial_d.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation1.alpha + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation1.beta + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation1.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear1.weight + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear1.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear1.bias + 
ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear1.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation2.alpha + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation2.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].activation2.beta + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].activation2.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear2.weight + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear2.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_x.residual.stack[0].linear2.bias + ref_spookynet.module[0].local_interaction.resblock_x.residual.stack[0].linear2.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resmlp_x.activation.alpha + 0].local_interaction.resblock_x.activation.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.activation.beta = ref_spookynet.module[ - 0].local_interaction.resmlp_x.activation.beta + 0].local_interaction.resblock_x.activation.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.linear.weight = ref_spookynet.module[ - 0].local_interaction.resmlp_x.linear.weight + 0].local_interaction.resblock_x.linear.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_x.linear.bias = ref_spookynet.module[ - 0].local_interaction.resmlp_x.linear.bias + 0].local_interaction.resblock_x.linear.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation1.alpha + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation1.beta + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation1.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear1.weight + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear1.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear1.bias + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear1.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation2.alpha + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation2.alpha 
spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].activation2.beta + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].activation2.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear2.weight + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear2.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_s.residual.stack[0].linear2.bias + ref_spookynet.module[0].local_interaction.resblock_s.residual.stack[0].linear2.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resmlp_s.activation.alpha + 0].local_interaction.resblock_s.activation.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.activation.beta = ref_spookynet.module[ - 0].local_interaction.resmlp_s.activation.beta + 0].local_interaction.resblock_s.activation.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.linear.weight = ref_spookynet.module[ - 0].local_interaction.resmlp_s.linear.weight + 0].local_interaction.resblock_s.linear.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_s.linear.bias = ref_spookynet.module[ - 0].local_interaction.resmlp_s.linear.bias + 0].local_interaction.resblock_s.linear.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation1.alpha + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation1.beta + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation1.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear1.weight + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear1.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear1.bias + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear1.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation2.alpha + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation2.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].activation2.beta + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].activation2.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear2.weight = \ - 
ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear2.weight + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear2.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_p.residual.stack[0].linear2.bias + ref_spookynet.module[0].local_interaction.resblock_p.residual.stack[0].linear2.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resmlp_p.activation.alpha + 0].local_interaction.resblock_p.activation.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.activation.beta = ref_spookynet.module[ - 0].local_interaction.resmlp_p.activation.beta + 0].local_interaction.resblock_p.activation.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.linear.weight = ref_spookynet.module[ - 0].local_interaction.resmlp_p.linear.weight + 0].local_interaction.resblock_p.linear.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_p.linear.bias = ref_spookynet.module[ - 0].local_interaction.resmlp_p.linear.bias + 0].local_interaction.resblock_p.linear.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation1.alpha + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation1.beta + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation1.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear1.weight + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear1.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear1.bias + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear1.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation2.alpha + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].activation2.beta + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].activation2.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear2.weight + ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear2.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_d.residual.stack[0].linear2.bias + 
ref_spookynet.module[0].local_interaction.resblock_d.residual.stack[0].linear2.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resmlp_d.activation.alpha + 0].local_interaction.resblock_d.activation.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.activation.beta = ref_spookynet.module[ - 0].local_interaction.resmlp_d.activation.beta + 0].local_interaction.resblock_d.activation.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.linear.weight = ref_spookynet.module[ - 0].local_interaction.resmlp_d.linear.weight + 0].local_interaction.resblock_d.linear.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_d.linear.bias = ref_spookynet.module[ - 0].local_interaction.resmlp_d.linear.bias + 0].local_interaction.resblock_d.linear.bias spookynet.core_module.interaction_modules[0].local_interaction.projection_p.weight = ref_spookynet.module[ 0].local_interaction.projection_p.weight spookynet.core_module.interaction_modules[0].local_interaction.projection_d.weight = ref_spookynet.module[ 0].local_interaction.projection_d.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation1.alpha + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation1.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation1.beta + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation1.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear1.weight + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear1.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear1.bias + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear1.bias spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation2.alpha + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation2.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].activation2.beta + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].activation2.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear2.weight + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear2.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].local_interaction.resmlp_l.residual.stack[0].linear2.bias + ref_spookynet.module[0].local_interaction.resblock.residual.stack[0].linear2.bias 
spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.activation.alpha = ref_spookynet.module[ - 0].local_interaction.resmlp_l.activation.alpha + 0].local_interaction.resblock.activation.alpha spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.activation.beta = ref_spookynet.module[ - 0].local_interaction.resmlp_l.activation.beta + 0].local_interaction.resblock.activation.beta spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.linear.weight = ref_spookynet.module[ - 0].local_interaction.resmlp_l.linear.weight + 0].local_interaction.resblock.linear.weight spookynet.core_module.interaction_modules[0].local_interaction.resmlp_l.linear.bias = ref_spookynet.module[ - 0].local_interaction.resmlp_l.linear.bias + 0].local_interaction.resblock.linear.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].activation1.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.beta = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].activation1.beta = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.weight = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].linear1.weight = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.bias = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].linear1.bias = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].activation2.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.beta = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].activation2.beta = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.weight = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].linear2.weight = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.bias = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.residual.stack[0].linear2.bias = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.activation.alpha = \ + 
spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.activation.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_q.activation.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.activation.beta = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.activation.beta = ref_spookynet.module[ 0].nonlocal_interaction.resblock_q.activation.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.linear.weight = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.linear.weight = ref_spookynet.module[ 0].nonlocal_interaction.resblock_q.linear.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_q.linear.bias = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_q.linear.bias = ref_spookynet.module[ 0].nonlocal_interaction.resblock_q.linear.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].activation1.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.beta = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].activation1.beta = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.weight = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].linear1.weight = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.bias = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].linear1.bias = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].activation2.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.beta = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].activation2.beta = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.weight = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].linear2.weight = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.bias = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.residual.stack[0].linear2.bias = \ 
ref_spookynet.module[0].nonlocal_interaction.resblock_k.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.activation.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.activation.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_k.activation.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.activation.beta = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.activation.beta = ref_spookynet.module[ 0].nonlocal_interaction.resblock_k.activation.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.linear.weight = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.linear.weight = ref_spookynet.module[ 0].nonlocal_interaction.resblock_k.linear.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_k.linear.bias = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_k.linear.bias = ref_spookynet.module[ 0].nonlocal_interaction.resblock_k.linear.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].activation1.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.beta = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].activation1.beta = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.weight = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].linear1.weight = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.bias = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].linear1.bias = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].activation2.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.beta = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].activation2.beta = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.weight = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].linear2.weight = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.weight - 
spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.bias = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.residual.stack[0].linear2.bias = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.activation.alpha = \ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.activation.alpha = \ ref_spookynet.module[0].nonlocal_interaction.resblock_v.activation.alpha - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.activation.beta = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.activation.beta = ref_spookynet.module[ 0].nonlocal_interaction.resblock_v.activation.beta - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.linear.weight = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.linear.weight = ref_spookynet.module[ 0].nonlocal_interaction.resblock_v.linear.weight - spookynet.core_module.interaction_modules[0].nonlocal_interaction.resblock_v.linear.bias = ref_spookynet.module[ + spookynet.core_module.interaction_modules[0].nonlocal_interaction.resmlp_v.linear.bias = ref_spookynet.module[ 0].nonlocal_interaction.resblock_v.linear.bias spookynet.core_module.interaction_modules[0].residual_pre.stack[0].activation1.alpha = \ @@ -371,28 +368,28 @@ def test_forward(): spookynet.core_module.interaction_modules[0].residual_post.stack[0].linear2.bias = \ ref_spookynet.module[0].residual_post.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation1.alpha = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].activation1.alpha - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation1.beta = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].activation1.beta - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear1.weight = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].linear1.weight - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear1.bias = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].linear1.bias - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation2.alpha = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].activation2.alpha - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].activation2.beta = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].activation2.beta - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear2.weight = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].linear2.weight - spookynet.core_module.interaction_modules[0].resmlp_l.residual.stack[0].linear2.bias = \ - ref_spookynet.module[0].resmlp_l.residual.stack[0].linear2.bias - spookynet.core_module.interaction_modules[0].resmlp_l.activation.alpha = ref_spookynet.module[ - 0].resmlp_l.activation.alpha - spookynet.core_module.interaction_modules[0].resmlp_l.activation.beta = ref_spookynet.module[ - 0].resmlp_l.activation.beta - spookynet.core_module.interaction_modules[0].resmlp_l.linear.weight = ref_spookynet.module[0].resmlp_l.linear.weight - spookynet.core_module.interaction_modules[0].resmlp_l.linear.bias = ref_spookynet.module[0].resmlp_l.linear.bias + 
spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].activation1.alpha = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation1.alpha + spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].activation1.beta = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation1.beta + spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].linear1.weight = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear1.weight + spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].linear1.bias = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear1.bias + spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].activation2.alpha = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation2.alpha + spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].activation2.beta = \ + ref_spookynet.module[0].resblock.residual.stack[0].activation2.beta + spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].linear2.weight = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear2.weight + spookynet.core_module.interaction_modules[0].resmlp.residual.stack[0].linear2.bias = \ + ref_spookynet.module[0].resblock.residual.stack[0].linear2.bias + spookynet.core_module.interaction_modules[0].resmlp.activation.alpha = ref_spookynet.module[ + 0].resblock.activation.alpha + spookynet.core_module.interaction_modules[0].resmlp.activation.beta = ref_spookynet.module[ + 0].resblock.activation.beta + spookynet.core_module.interaction_modules[0].resmlp.linear.weight = ref_spookynet.module[0].resblock.linear.weight + spookynet.core_module.interaction_modules[0].resmlp.linear.bias = ref_spookynet.module[0].resblock.linear.bias spookynet.core_module.energy_and_charge_readout.weight = ref_spookynet.output.weight ref_spookynet.train() @@ -557,9 +554,9 @@ def test_spookynet_interaction_module_against_reference(): mf_param[:] = ref_param assert len(list(mf_spookynet_interaction_module.resblock.named_parameters())) == len( - list(ref_spookynet_interaction_module.resmlp_l.named_parameters())) + list(ref_spookynet_interaction_module.resblock.named_parameters())) for (mf_name, mf_param), (ref_name, ref_param) in zip(mf_spookynet_interaction_module.resblock.named_parameters(), - ref_spookynet_interaction_module.resmlp_l.named_parameters()): + ref_spookynet_interaction_module.resblock.named_parameters()): print(f"{mf_name=} {ref_name=}") if not torch.equal(mf_param, ref_param): print(f"{mf_param=} {ref_param=}") From 57e7b498aea2afca59baf2998ce6fa23f9000211 Mon Sep 17 00:00:00 2001 From: Arnav Nagle Date: Wed, 21 Aug 2024 06:47:25 -0700 Subject: [PATCH 78/78] Update test so that it compares the results of running the modelforge implementation on a batched methane with the reference implementation on a single methane --- modelforge/tests/test_spookynet.py | 52 ++++++++++++++++-------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/modelforge/tests/test_spookynet.py b/modelforge/tests/test_spookynet.py index c905dfc6..2de93833 100644 --- a/modelforge/tests/test_spookynet.py +++ b/modelforge/tests/test_spookynet.py @@ -59,25 +59,29 @@ def test_forward(): ).double() single_methane = setup_single_methane_input()["modelforge_methane_input"] - model_input = NNPInput( + double_methane = NNPInput( atomic_numbers=torch.cat([single_methane.atomic_numbers] * 2, dim=0), positions=torch.cat([single_methane.positions] * 2, dim=0), atomic_subsystem_indices=torch.tensor([0, 0, 0, 0, 0, 1, 
1, 1, 1, 1]), total_charge=torch.cat([single_methane.total_charge] * 2, dim=0), ) print(f"{torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int64).dtype=}") - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - ic(model_input) - model_input.positions = model_input.positions.double() - model_input.total_charge = model_input.total_charge.double() - - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - spookynet.input_preparation._input_checks(model_input) - - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - pairlist_output = spookynet.input_preparation.prepare_inputs(model_input) - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - calculated_results = spookynet.core_module.forward(model_input, pairlist_output) + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + ic(double_methane) + single_methane.positions = single_methane.positions.double() + single_methane.total_charge = single_methane.total_charge.double() + double_methane.positions = double_methane.positions.double() + double_methane.total_charge = double_methane.total_charge.double() + + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + spookynet.input_preparation._input_checks(single_methane) + single_pairlist_output = spookynet.input_preparation.prepare_inputs(single_methane) + + spookynet.input_preparation._input_checks(double_methane) + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + double_pairlist_output = spookynet.input_preparation.prepare_inputs(double_methane) + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + calculated_results = spookynet.core_module.forward(double_methane, double_pairlist_output) ref_spookynet = RefSpookyNet( num_features=config["potential"]["core_parameter"]["number_of_atom_features"], @@ -395,19 +399,19 @@ def test_forward(): ref_spookynet.train() # TODO: how are multiple systems passed into the reference SpookyNet - print(f"test: {model_input.atomic_subsystem_indices.dtype=}") - reference_calculated_results = ref_spookynet( - model_input.atomic_numbers, - model_input.total_charge, - (model_input.positions * unit.nanometer).m_as(unit.angstrom), - pairlist_output.pair_indices[0], - pairlist_output.pair_indices[1], - batch_seg=model_input.atomic_subsystem_indices.long(), - num_batch=2, + print(f"test: {double_methane.atomic_subsystem_indices.dtype=}") + energy, forces, dipole, f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6 = ref_spookynet( + single_methane.atomic_numbers, + single_methane.total_charge, + (single_methane.positions * unit.nanometer).m_as(unit.angstrom), + single_pairlist_output.pair_indices[0], + single_pairlist_output.pair_indices[1], + batch_seg=None, + num_batch=1, ) - ic(calculated_results.keys()) - ic(reference_calculated_results) + ic(calculated_results["per_atom_energy"]) + ic(ea) def test_spookynet_forward(single_batch_with_batchsize_64, model_parameter):
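The hundreds of attribute-by-attribute assignments in test_forward() all implement one pattern: copy every reference parameter onto the correspondingly named modelforge parameter, translating resblock-style names into resmlp-style ones. A minimal sketch of that pattern follows; copy_reference_parameters and its rename argument are illustrative assumptions rather than functions in either codebase, and a real name map would need per-module adjustments (e.g. resmlp_l corresponds to a plain resblock in some reference submodules).

    import torch
    import torch.nn as nn

    def copy_reference_parameters(
        mf_module: nn.Module,
        ref_module: nn.Module,
        rename=lambda name: name.replace("resmlp", "resblock"),
    ) -> None:
        # Look up each modelforge parameter's counterpart by (renamed) name
        # and copy the values in place, so that both modules compute
        # identical outputs afterwards.
        ref_params = dict(ref_module.named_parameters())
        with torch.no_grad():
            for mf_name, mf_param in mf_module.named_parameters():
                mf_param.copy_(ref_params[rename(mf_name)])

    # Hypothetical usage mirroring the manual block above:
    # copy_reference_parameters(
    #     spookynet.core_module.interaction_modules[0],
    #     ref_spookynet.module[0],
    # )

This is essentially what test_spookynet_interaction_module_against_reference already does with its zip over named_parameters() and mf_param[:] = ref_param, generalized to lookup by name so that parameter ordering does not have to match between the two implementations.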