Description
Is this a new feature, an improvement, or a change to existing functionality?
Improvement
How would you describe the priority of this feature request
Low (would be nice)
Please provide a clear description of the problem you would like to solve.
The following suggestions from @vpuri3 could help improve the accuracy of GAFLARE attention:
- Try MLP(C -> 4C -> GeLU -> C) in place of the linear layers used for self.self_k and self.self_v in GAFLARE.
- Try cross attention for context outlined on page 71 (5.1.2 Aim 1(b): conditioning mechanism for dynamic PDE surrogates) of this document: https://drive.google.com/file/d/1SNDjQ0gMSZmv0jg49S-risEoDiwE63aY/view?usp=sharing
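Suggestion 1 could be prototyped roughly as below. This is a minimal sketch, assuming a standard PyTorch module; the attribute names `self_k`/`self_v` come from the issue text, while `kv_mlp` and the surrounding usage comments are illustrative, not part of the actual GAFLARE code.

```python
import torch
import torch.nn as nn

def kv_mlp(dim: int, expansion: int = 4) -> nn.Sequential:
    """C -> 4C -> GeLU -> C projection, a drop-in for a plain nn.Linear(C, C)."""
    return nn.Sequential(
        nn.Linear(dim, expansion * dim),
        nn.GELU(),
        nn.Linear(expansion * dim, dim),
    )

# Illustrative usage inside the attention module (names assumed):
#   self.self_k = kv_mlp(hidden_dim)  # instead of nn.Linear(hidden_dim, hidden_dim)
#   self.self_v = kv_mlp(hidden_dim)

x = torch.randn(2, 16, 64)  # (batch, tokens, C)
print(kv_mlp(64)(x).shape)  # same shape as a linear projection: (2, 16, 64)
```

The output shape matches the linear layer it replaces, so the change should be local to the projection definitions.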
Reason for 1:
@vpuri3: I've found that using a more expressive projection here really helps performance on PDE problems.
The tradeoff here is described in Appendix F under heading "Tradeoff between query dynamics and key/value expressivity" in the paper: https://arxiv.org/pdf/2508.12594.
For PDE problems, I've found that replacing FFN type layers (C -> 4C -> GeLU -> C) with deeper but narrower MLPs can help because the mapping is often smoother / more “function-approximation-like,” and gains come from expressive feature transforms more than from content-addressable routing/memorization.
Here's the full model definition I used in the experiments in the paper:
https://github.com/vpuri3/FLARE.py/blob/master/pdebench/models/flare.py
I understand that deep KV projections would increase parameter counts. To compensate for that, we have validated that FLARE performs on par with other models at smaller hidden sizes (FLARE at C=64 outperforms Transolver at C=128).
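Suggestion 2 could be sketched as a generic cross-attention conditioning block along the following lines. This is only an assumption about the mechanism: the exact design is specified in the linked document, and the class name, residual-plus-norm wiring, and head count here are illustrative defaults, not taken from that document.

```python
import torch
import torch.nn as nn

class ContextCrossAttention(nn.Module):
    """Queries from PDE state tokens; keys/values from a context sequence."""
    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # x: (batch, N, C) state tokens; context: (batch, M, C) conditioning tokens
        attended, _ = self.attn(query=x, key=context, value=context)
        # Residual + norm is a common default; the actual conditioning
        # mechanism in the document may wire this differently.
        return self.norm(x + attended)

x = torch.randn(2, 16, 64)
ctx = torch.randn(2, 8, 64)
print(ContextCrossAttention(64)(x, ctx).shape)  # (2, 16, 64)
```

The context length M need not match the number of state tokens N, which is what makes cross attention a natural fit for conditioning dynamic PDE surrogates on auxiliary inputs.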
Describe any alternatives you have considered
No response