Skip to content

Commit

Permalink
schnet with einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Nov 6, 2024
1 parent c1430f2 commit 64e4cfa
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,11 @@ def forward(

# Generate interaction filters based on radial basis functions
W_ij = self.filter_network(f_ij.squeeze(1))
W_ij = W_ij * f_ij_cutoff # Shape: [n_pairs, number_of_filters]
W_ij = torch.einsum("nf,n->nf", W_ij, f_ij_cutoff.squeeze(-1))

# Perform continuous-filter convolution
x_j = atomic_embedding[idx_j]
x_ij = x_j * W_ij # Element-wise multiplication
x_ij = torch.einsum("nk,nk->nk", W_ij, x_j)

out = torch.zeros_like(atomic_embedding).scatter_add_(
0, idx_i.unsqueeze(-1).expand_as(x_ij), x_ij
Expand Down

0 comments on commit 64e4cfa

Please sign in to comment.