Skip to content

Commit

Permalink
Merge pull request #307 from choderalab/dev-aimnet2-new
Browse files Browse the repository at this point in the history
Update to AIMNet2 architecture for radial and vector embedding
  • Loading branch information
wiederm authored Nov 7, 2024
2 parents fa4814f + 9115792 commit cf5b7c3
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 266 deletions.
497 changes: 237 additions & 260 deletions modelforge/potential/aimnet2.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion modelforge/potential/featurization.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,10 @@ def forward(self, data: NNPInput) -> torch.Tensor:
torch.Tensor
The featurized input data.
"""

atomic_numbers = data.atomic_numbers
categorial_embedding = self.atomic_number_embedding(atomic_numbers)
if torch.isnan(categorial_embedding).any():
raise ValueError("NaN values detected in categorial_embedding.")

for additional_embedding in self.embeddings:
categorial_embedding = additional_embedding(categorial_embedding, data)
Expand Down
3 changes: 3 additions & 0 deletions modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Featurization(BaseModel):
class ActivationFunctionName(CaseInsensitiveEnum):
ReLU = "ReLU"
CeLU = "CeLU"
GeLU = "GeLU"
Sigmoid = "Sigmoid"
Softmax = "Softmax"
ShiftedSoftplus = "ShiftedSoftplus"
Expand All @@ -74,6 +75,7 @@ class ActivationFunctionName(CaseInsensitiveEnum):
class ActivationFunctionParamsEnum(CaseInsensitiveEnum):
ReLU = "None"
CeLU = ActivationFunctionParamsAlpha
GeLU = "None"
Sigmoid = "None"
Softmax = "None"
ShiftedSoftplus = "None"
Expand Down Expand Up @@ -177,6 +179,7 @@ class CoreParameter(CoreParameterBase):
featurization: Featurization
predicted_properties: List[str]
predicted_dim: List[int]
number_of_vector_features: int
converted_units = field_validator("maximum_interaction_radius", mode="before")(
_convert_str_or_unit_to_unit_length
)
Expand Down
1 change: 1 addition & 0 deletions modelforge/potential/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def scatter_softmax(
ACTIVATION_FUNCTIONS = {
"ReLU": nn.ReLU,
"CeLU": nn.CELU,
"GeLU": nn.GELU,
"Sigmoid": nn.Sigmoid,
"Softmax": nn.Softmax,
"ShiftedSoftplus": ShiftedSoftplus,
Expand Down
3 changes: 2 additions & 1 deletion modelforge/tests/data/potential_defaults/aimnet2.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ potential_name = "AimNet2"

[potential.core_parameter]
number_of_radial_basis_functions = 32
number_of_vector_features = 5
maximum_interaction_radius = "5.0 angstrom"
number_of_interaction_modules = 3
predicted_properties = ["per_atom_energy"]
predicted_dim = [1]

[potential.core_parameter.activation_function_parameter]
activation_function_name = "ShiftedSoftplus"
activation_function_name = "GeLU"

[potential.core_parameter.featurization]
properties_to_featurize = ['atomic_number']
Expand Down
75 changes: 74 additions & 1 deletion modelforge/tests/test_aimnet2.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,86 @@ def test_radial_symmetry_function_regression():
def test_forward(single_batch_with_batchsize, prep_temp_dir):
"""Test initialization of the AIMNet2 model."""
# read default parameters
aimnet = setup_potential_for_test("aimnet2", "training")
aimnet = setup_potential_for_test("aimnet2", "training", potential_seed=42)

assert aimnet is not None, "Aimnet model should be initialized."
batch = single_batch_with_batchsize(64, "QM9", str(prep_temp_dir))

y_hat = aimnet(batch.nnp_input)

assert y_hat is not None, "Aimnet model should be able to make predictions."

ref_per_system_energy = torch.tensor(
[
[-1.6222e00],
[-1.7771e-01],
[1.5974e-01],
[-1.2089e-02],
[-1.8864e-01],
[-2.7185e-01],
[-4.3214e00],
[-1.3357e00],
[-1.1657e00],
[-1.4146e00],
[-1.8898e00],
[-1.1582e00],
[-9.1212e00],
[-4.8285e00],
[-5.0907e00],
[-5.4467e00],
[-1.8100e00],
[-4.9845e00],
[-3.7676e00],
[-2.5988e00],
[-1.5824e01],
[-1.0948e01],
[-2.8324e-01],
[-4.5179e-01],
[-6.8437e-01],
[-3.1547e-01],
[-5.7387e-01],
[-4.6788e-01],
[-1.9818e00],
[-3.8900e00],
[-4.2745e00],
[-2.8107e00],
[-1.2960e00],
[-1.5892e00],
[-5.7663e00],
[-4.2937e00],
[-3.0977e00],
[-2.2906e00],
[-1.4034e01],
[-9.6701e00],
[-7.9657e00],
[-6.4762e00],
[-9.7999e00],
[-5.6619e00],
[-9.1679e00],
[-6.8304e00],
[-1.0582e01],
[-6.0419e00],
[-7.2018e00],
[-5.0521e00],
[-4.0748e00],
[-3.5285e00],
[-2.5017e00],
[-2.5237e01],
[-1.9461e01],
[-1.7413e00],
[-2.1273e00],
[-2.5887e00],
[-1.1963e00],
[-2.4938e00],
[-3.1271e00],
[-1.7812e00],
[-8.0866e00],
[-8.7542e00],
],
)

assert torch.allclose(y_hat["per_system_energy"], ref_per_system_energy, atol=1e-3)


@pytest.mark.xfail(raises=NotImplementedError)
def test_against_original_implementation():
Expand Down
3 changes: 2 additions & 1 deletion modelforge/tests/test_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ def test_equivariant_energies_and_forces(
# define the symmetry operations
translation, rotation, reflection = equivariance_utils
# define the tolerance
atol = 1e-3
atol = 1e-1

# ------------------- #
# start the test
Expand Down Expand Up @@ -1060,6 +1060,7 @@ def test_equivariant_energies_and_forces(
)[0]

rotate_reference = rotation(reference_forces)
print(rotation_forces, rotate_reference)
assert torch.allclose(
rotation_forces,
rotate_reference,
Expand Down
1 change: 0 additions & 1 deletion modelforge/train/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ class WandbConfig(ParametersBase):
notes: Optional[str]


# Move SchedulerConfig classes outside of TrainingParameters
class SchedulerConfigBase(ParametersBase):
"""
Base class for scheduler configurations
Expand Down
7 changes: 6 additions & 1 deletion modelforge/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,12 @@ def configure_optimizers(self):
bias_params = []

for name, param in self.potential.named_parameters():
if "weight" in name or "atomic_shift" in name or "gate" in name:
if (
"weight" in name
or "atomic_shift" in name
or "gate" in name
or "agh" in name
):
weight_params.append(param)
elif "bias" in name or "atomic_scale" in name:
bias_params.append(param)
Expand Down

0 comments on commit cf5b7c3

Please sign in to comment.