Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to AIMNet2 architecture for radial and vector embedding #307

Merged
merged 16 commits into from
Nov 7, 2024

Conversation

wiederm
Copy link
Member

@wiederm wiederm commented Nov 4, 2024

Update to AIMNet2 architecture for radial and vector embedding. Updated the wiki entry.

Summary of changes:

This PR introduces significant improvements to the AimNet2 architecture, specifically in handling radial and vector contributions. The original implementation was limited in capturing angular dependencies and omitted clear aggregation and transformations for certain atomic interactions. These changes aim to enhance model expressiveness and bring it closer to the (somewhat vague) description in the original paper.

Key Modifications:

  • Refinement of Radial and Vector Features: gs and gv tensors are introduced to represent radial and vector (angular) symmetry functions, respectively. The radial basis (gs) captures pairwise distances, while gv encodes directional vectors for atom pairs.
  • Vector Contribution Computation: using an Einstein summation operation (torch.einsum), gv and an agh transformation matrix capture the vector (angular) contributions. This setup adds expressiveness by handling directional information, aligning the model closer with angular-dependent symmetry functions. The operation structure: torch.einsum("pa, pdg, agh -> phd", a_j, gv, agh) ensures the features align over vector components.
  • Enhanced Broadcasting and Aggregation: to match dimensions, radial and vector tensors undergo explicit broadcasting (via .unsqueeze()). This step avoids mismatches and clarifies the contribution structure.
  • Aggregated Contributions: output radial and vector contributions for each pair are aggregated per atom using index_add_ to collect contributions correctly.
  • Scaling with f Factor: a scaling factor f is applied to delta_q (charge updates), which is introduced to improve control over charge accumulation during message-passing steps. This approach allows initializing and incrementally updating charges in a more controlled manner.
  • Unified Message Module: each message-passing module now utilizes an MLP with outputs split into delta_q, f, and delta_a, making the function of each component explicit. This organization supports independent updates for atomic embeddings (delta_a) and partial charges (delta_q), providing finer control over per-atom features and charge states.
  • Structural Adjustments for Message Passing: the first interaction module updates only atomic embeddings, while subsequent modules also modify partial charges. The message composition in later modules combines features from both embeddings and charges, producing a richer atomic feature set.

Paper Vague on Equations
The original paper is vague regarding details, especially for the equations governing vector contributions and aggregation. This implementation assumes certain interpretations based on the partial information given, and are not easily to decode. This is therefore not a literal implementation of AimNet2, but only a interpretation.

@wiederm wiederm self-assigned this Nov 5, 2024
@wiederm wiederm added the refactoring Improve the quality of the code without functional changes label Nov 5, 2024
@wiederm wiederm requested a review from MarshallYan November 5, 2024 19:52
@codecov-commenter
Copy link

codecov-commenter commented Nov 5, 2024

Codecov Report

Attention: Patch coverage is 98.24561% with 1 line in your changes missing coverage. Please review.

Project coverage is 85.52%. Comparing base (65d2b43) to head (9115792).
Report is 9 commits behind head on main.

Additional details and impacted files

@wiederm wiederm merged commit cf5b7c3 into main Nov 7, 2024
5 of 6 checks passed
@wiederm wiederm deleted the dev-aimnet2-new branch November 7, 2024 14:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
refactoring Improve the quality of the code without functional changes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants