Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SpookyNet potential #133

Open
wants to merge 95 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
faf9c9b
Copy SchNet implementation and rename to SpookyNet
ArnNag May 29, 2024
c3575af
Merge branch 'main' into spookynet
ArnNag May 29, 2024
6536b59
Add spookynet to test environment and list of models. Rename SAKE inp…
ArnNag May 31, 2024
359f632
Merge branch 'main' of github.com:choderalab/modelforge into spookynet
ArnNag Jun 3, 2024
a6e8ce5
Copy SpookyNet interaction module code from reference implementation.…
ArnNag Jun 5, 2024
06d85ca
Merge branch 'main' into spookynet
ArnNag Jun 6, 2024
f3448c6
Merge remote-tracking branch 'choderalab/main' into spookynet
ArnNag Jun 11, 2024
7d858b4
Merge branch 'main' into spookynet
ArnNag Jun 14, 2024
fe5cd04
Hard code one batch (relevant within attention). Rename variables to …
ArnNag Jun 20, 2024
2747fba
More changes. Starting to implement tests
ArnNag Jun 25, 2024
b4f1b72
Merge remote-tracking branch 'origin/spookynet' into spookynet
ArnNag Jun 25, 2024
4a5e072
Implement equivalence test for SpookyNet interaction block. First ele…
ArnNag Jun 26, 2024
16538b6
Implement SpookyNet radial basis function. Test failing (probably due…
ArnNag Jun 26, 2024
9a210b8
Begin refactoring radial basis functions
ArnNag Jun 27, 2024
2930e44
Merge branch 'main' into refactor_rbf
ArnNag Jun 27, 2024
6e7b10e
Update RadialBasisFunction comment
ArnNag Jun 27, 2024
b6f1149
Fix references to radial basis centers and scale factors
ArnNag Jun 27, 2024
7757010
Continue refactoring radial basis functions
ArnNag Jun 27, 2024
b4ea01f
Merge remote-tracking branch 'origin/refactor_rbf' into refactor_rbf
ArnNag Jun 27, 2024
e4af371
Fix SchNet RBF
ArnNag Jun 28, 2024
6398cf8
Fix PhysNet and ANI RBF
ArnNag Jun 28, 2024
7e375e8
Trying to fix SAKE but failing
ArnNag Jun 28, 2024
e48a395
Fix SAKE test
ArnNag Jul 2, 2024
fdfa4d5
Remove unnecessary test
ArnNag Jul 2, 2024
c4b24b4
Refactor prefactor
ArnNag Jul 2, 2024
25a5ac9
Update comment for nondimensionalization
ArnNag Jul 2, 2024
c0280fb
Fix SAKE and ANI bugs
ArnNag Jul 2, 2024
a3bfc0e
Fix test radial symmetry function
ArnNag Jul 2, 2024
9af3caf
Merge branch 'main' into refactor_rbf
ArnNag Jul 2, 2024
fc11a29
Merge branch 'refs/heads/refactor_rbf' into spookynet
ArnNag Jul 2, 2024
6b9f171
Update utils.py
wiederm Jul 3, 2024
a4f41b9
typo fix
wiederm Jul 3, 2024
0e492d7
Make RadialBasisFunctionCore inherit from nn.Module
ArnNag Jul 5, 2024
bb387c5
Remove @staticmethod decorator
ArnNag Jul 5, 2024
5f35b72
Try to implement Bernstein polynomials
ArnNag Jul 6, 2024
22acdc7
Add shape assertion in RadialBasisFunctionCore. Remove unnecessary un…
ArnNag Jul 6, 2024
cb28ef1
Remove print statement
ArnNag Jul 6, 2024
1e88783
Remove print statements in SAKE
ArnNag Jul 6, 2024
5b0e9bc
Fix SAKE RBF test
ArnNag Jul 6, 2024
c0ff49e
Merge branch 'main' into refactor_rbf
ArnNag Jul 9, 2024
addc502
Refactor exponential Bernstein polynomials
ArnNag Jul 9, 2024
1f11f3a
Fix SchNet tests
ArnNag Jul 9, 2024
2109f3c
Fix spk tests
ArnNag Jul 9, 2024
d08c9c6
Clean spk test
ArnNag Jul 9, 2024
0f5b1fc
Merge branch 'refs/heads/refactor_rbf' into spookynet
ArnNag Jul 9, 2024
0ff04f7
Working on spookynet Bernstein polynomials. Weird shape assertion fail.
ArnNag Jul 9, 2024
3da2330
Broadcast to number of radial basis functions in nondimensionalizatio…
ArnNag Jul 9, 2024
e24b6fa
update toml
wiederm Jul 10, 2024
23ea05f
transition to new version of postprocessing
wiederm Jul 10, 2024
9632a7d
updated toml
wiederm Jul 10, 2024
dc63f13
update names, intial postprocessing implementation
Jul 10, 2024
645eb35
working prototype
wiederm Jul 10, 2024
3160923
update all toml files and all potential output names
wiederm Jul 11, 2024
77756f9
update input signature and toml files
wiederm Jul 11, 2024
7f841a9
remove legacy code
wiederm Jul 11, 2024
93f0f28
update
wiederm Jul 11, 2024
a0dde31
update loss docstrings
wiederm Jul 11, 2024
aeec947
fixing tests
wiederm Jul 11, 2024
d7e819c
fixing tests and names
wiederm Jul 11, 2024
cab0ed9
fix test
wiederm Jul 11, 2024
9b6086f
Merge branch 'main' into ref-postprocessing-and-loss
wiederm Jul 11, 2024
03a5930
error calculation tests
wiederm Jul 11, 2024
9293cba
Merge branch 'ref-postprocessing-and-loss' of https://github.com/chod…
wiederm Jul 11, 2024
90d4b66
update
wiederm Jul 11, 2024
62311c0
add flax
wiederm Jul 11, 2024
909238a
update yaml
wiederm Jul 11, 2024
629df00
update
wiederm Jul 11, 2024
5cf4314
Add comments for todos
ArnNag Jul 11, 2024
bc78257
Merge remote-tracking branch 'upstream/ref-postprocessing-and-loss' i…
ArnNag Jul 11, 2024
f522208
Fix schnet test
ArnNag Jul 11, 2024
eda75b9
Fix test physnet compare representation
ArnNag Jul 11, 2024
3e0e4d7
Update SpookyNet for postprocessing
ArnNag Jul 11, 2024
c4d61fe
Working through forward test
ArnNag Jul 12, 2024
f7b132d
Fix bug that resulted in incorrect dir_ij
ArnNag Jul 12, 2024
9a52c01
Replace * with einsum
ArnNag Jul 12, 2024
2ec3f3f
Replace more operations with einsum
ArnNag Jul 12, 2024
5a32fe4
Replace one more operation with torch.einsum. Fix axis label for n
ArnNag Jul 12, 2024
6ce39e4
Add explicit broadcast. Reformat code.
ArnNag Jul 12, 2024
b7a2b28
Fix ini_alpha with units
ArnNag Jul 12, 2024
731bd7b
Trying to implement embeddings
ArnNag Jul 15, 2024
71c3fea
More changes
ArnNag Jul 16, 2024
e548bf1
More changes
ArnNag Jul 20, 2024
5ea8988
Fix atomic embedding
ArnNag Jul 23, 2024
b61649e
Merge branch 'refs/heads/main' into spookynet
ArnNag Jul 25, 2024
5a60c2c
Fix merge conflict issues
ArnNag Jul 25, 2024
e4dd91c
Remove unnecessarily hard-coded arrays in the model implementation an…
ArnNag Jul 25, 2024
5e00a6d
Changes to tests and model
ArnNag Jul 25, 2024
6ad3712
Copy parameters
ArnNag Jul 25, 2024
0539ed3
Add line breaks between different blocks when copying parameters
ArnNag Jul 25, 2024
b613586
Changes to tests and model
ArnNag Jul 26, 2024
10c9914
More changes
ArnNag Jul 30, 2024
d170b79
Update docstrings
ArnNag Aug 13, 2024
5f9666d
Fix naming resblock -> resmlp
ArnNag Aug 13, 2024
57e7b49
Update test so that it compares the results of running the modelforge…
ArnNag Aug 21, 2024
b28d660
Merge branch 'main' into spookynet
wiederm Nov 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions modelforge/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SAKEParameters,
SchNetParameters,
TensorNetParameters,
SpookyNetParameters,
)
from .processing import FromAtomToMoleculeReduction
from .representation import (
Expand Down Expand Up @@ -63,6 +64,7 @@ class _Implemented_NNPs(Enum):
PAINN = PaiNNCore
SAKE = SAKECore
AIMNET2 = AimNet2Core
SPOOKYNET = SpookyNet

@classmethod
def get_neural_network_class(cls, neural_network_name: str):
Expand Down
1,193 changes: 1,193 additions & 0 deletions modelforge/potential/spookynet.py

Large diffs are not rendered by default.

92 changes: 92 additions & 0 deletions modelforge/potential/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,105 @@ def forward(self, x: torch.Tensor):

Returns:
-----------

torch.Tensor
Shifted soft-plus of the input.
"""

return F.softplus(x) - self.log_2



class ExponentialBernsteinPolynomialsCore(RadialBasisFunctionCore):
"""
Taken from SpookyNet.
Radial basis functions based on exponential Bernstein polynomials given by:
b_{v,n}(x) = (n over v) * exp(-alpha*x)**v * (1-exp(-alpha*x))**(n-v)
(see https://en.wikipedia.org/wiki/Bernstein_polynomial)
Here, n = num_basis_functions-1 and v takes values from 0 to n. This
implementation operates in log space to prevent multiplication of very large
(n over v) and very small numbers (exp(-alpha*x)**v and
(1-exp(-alpha*x))**(n-v)) for numerical stability.
NOTE: There is a problem for x = 0, as log(-expm1(0)) will be log(0) = -inf.
This itself is not an issue, but the buffer v contains an entry 0 and
0*(-inf)=nan. The correct behaviour could be recovered by replacing the nan
with 0.0, but should not be necessary because issues are only present when
r = 0, which will not occur with chemically meaningful inputs.

Arguments:
number_of_radial_basis_functions (int):
Number of radial basis functions.
x = infinity.
"""

def __init__(self, number_of_radial_basis_functions: int):
super().__init__(number_of_radial_basis_functions)
logfactorial = np.zeros(number_of_radial_basis_functions)
for i in range(2, number_of_radial_basis_functions):
logfactorial[i] = logfactorial[i - 1] + np.log(i)
v = np.arange(0, number_of_radial_basis_functions)
n = (number_of_radial_basis_functions - 1) - v
logbinomial = logfactorial[-1] - logfactorial[v] - logfactorial[n]
# register buffers and parameters
dtype = torch.float64 # TODO: make this a parameter
self.logc = torch.tensor(logbinomial, dtype=dtype)
self.n = torch.tensor(n, dtype=dtype)
self.v = torch.tensor(v, dtype=dtype)

def forward(self, nondimensionalized_distances: torch.Tensor) -> torch.Tensor:
"""
Evaluates radial basis functions given distances
N: Number of input values.
num_basis_functions: Number of radial basis functions.

Arguments:
nondimensionalized_distances (FloatTensor [N]):
Input distances.

Returns:
rbf (FloatTensor [N, num_basis_functions]):
Values of the radial basis functions for the distances r.
"""
assert nondimensionalized_distances.ndim == 2
assert (
nondimensionalized_distances.shape[1]
== self.number_of_radial_basis_functions
)
x = (
self.logc
+ (self.n + 1) * nondimensionalized_distances
+ self.v * torch.log(-torch.expm1(nondimensionalized_distances))
)

return torch.exp(x)

class ExponentialBernsteinRadialBasisFunction(RadialBasisFunction):

def __init__(self,
number_of_radial_basis_functions: int,
ini_alpha: unit.Quantity = 2.0 * unit.bohr,
dtype=torch.int64):
"""
ini_alpha (float):
Initial value for scaling parameter alpha (alpha here is the reciprocal of alpha in the paper. The original
default is 0.5/bohr, so we use 2 bohr).
"""
super().__init__(
ExponentialBernsteinPolynomialsCore(number_of_radial_basis_functions),
trainable_prefactor=False,
dtype=dtype,
)
self.register_parameter("alpha", nn.Parameter(torch.tensor(ini_alpha.m_as(unit.nanometer))))

def nondimensionalize_distances(self, d_ij: torch.Tensor) -> torch.Tensor:
return -(
d_ij.broadcast_to(
(len(d_ij), self.radial_basis_function.number_of_radial_basis_functions)
)
/ self.alpha
)


def pair_list(
atomic_subsystem_indices: torch.Tensor,
only_unique_pairs: bool = False,
Expand Down
18 changes: 18 additions & 0 deletions modelforge/tests/data/potential_defaults/spookynet.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[potential]
model_name = "SpookyNet"

[potential.core_parameter]
max_Z = 87
number_of_atom_features = 64
number_of_radial_basis_functions = 16
cutoff = "5.291772105638412 angstrom" # 10 a0
number_of_interaction_modules = 3
number_of_residual_blocks = 1

[potential.postprocessing_parameter]
[potential.postprocessing_parameter.per_atom_energy]
normalize = true
from_atom_to_molecule_reduction = true
keep_per_atom_property = true
[potential.postprocessing_parameter.general_postprocessing_operation]
calculate_molecular_self_energy = true
2 changes: 0 additions & 2 deletions modelforge/tests/test_schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ def test_compare_implementation_against_reference_implementation():
# test the implementation of the representation part of the PaiNN model
# ---------------------------------------- #
model = setup_schnet_model(1234).double()
# ---------------------------------------- #

# ------------------------------------ #
# reference values
# generated with schnetpack2.0
Expand Down
Loading
Loading