GeoFLARE: added GALE_FA, an alternate attention to GALE, for GeoTransolver#1405
dakhare-creator wants to merge 20 commits into NVIDIA:main from
Conversation
Greptile Overview
Greptile Summary
This PR adds GAFLARE (Geometry-Aware FLARE), an alternative attention mechanism to GALE for the GeoTransolver model. Key Changes:
Critical Issues Found:
Minor Issues:
Important Files Changed
self.use_te = False  # te will disable FlashAttention for different size of q and k
self.scale = 1.  # dim_head**-0.5
use_te parameter ignored - hardcoded to False on line 121, making the use_te constructor parameter ineffective.
Suggested change:
-    self.use_te = False  # te will disable FlashAttention for different size of q and k
-    self.scale = 1.  # dim_head**-0.5
     super().__init__()
+    self.use_te = use_te and TE_AVAILABLE
raise ValueError() added.
    **kwargs,
):
    self.use_te = False  # te will disable FlashAttention for different size of q and k
    self.scale = 1.  # dim_head**-0.5
self.scale is set to 1.0, but the commented-out code suggests it should be dim_head**-0.5 for proper attention scaling. The current implementation may affect attention quality.
Suggested change:
-    self.scale = 1.  # dim_head**-0.5
+    self.scale = dim_head**-0.5
Updated to:
self.scale = 1.  # FLARE scale is 1.0
The FLARE reference code uses scale 1.0.
context = torch.randn(2, 8, 64, 32).to(device)
gaflare = GAFLARE(dim=256, heads=8, dim_head=32, context_dim=32).to(device)
outputs = gaflare(x, context)
print(outputs[0].shape)  # (no newline at end of file)
Missing newline at end of file.
Suggested change: add a trailing newline after the final print(outputs[0].shape) line.
:class:`GALE` : Origional GeoTransolver GALE attention class.
:class:`GALE_block` : Transformer block using GAFLARE attention.
Typo: "Origional" should be "Original"
Suggested change:
-    :class:`GALE` : Origional GeoTransolver GALE attention class.
-    :class:`GALE_block` : Transformer block using GAFLARE attention.
+    :class:`GALE` : Original GeoTransolver GALE attention class.
+    :class:`GAFLARE_block` : Transformer block using GAFLARE attention.
Origional -> Original
Additional Comments (1)
ATTENTION_CLASSES = {
"GALE": GALE,
"GAFLARE": GAFLARE,
}
if attention_type not in ATTENTION_CLASSES:
raise ValueError(f"Unknown attention_type: {attention_type}")
self.Attn = ATTENTION_CLASSES[attention_type](
hidden_dim,
heads=num_heads,
# ... rest of arguments
)
coreyjadams left a comment:
Hi @dakhare-creator - thanks for opening this. Overall it seems reasonable. Good to put it in experimental. I made a couple comments in the code, but also:
- it'd be unusual for us to publish a "GeoFlare" model without also making "Flare" available. Can you add standard FLARE attention to the experimental/nn folder so we can do that too?
- We can also probably easily create a flare.py itself for the model.
- How much overlap in the math and attention blocks is there between PhysicsAttention and FLARE Attention? PhysicsAttention can be used on 2d and 3d data, could this? I strongly suspect yes. Can we make this modular so we can reuse as much as possible between attention layers?
Once we're ready, let's also invite the FLARE authors to take a look.
    # GALE attention layer
-   self.Attn = GALE(
+   if attention_type in globals():
+       self.Attn = globals()[attention_type](
Let's avoid using globals() like this; I think it's not needed. You can just use a match to get the attention type.

Updated with match.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = (torch.randn(2, 100, 256).to(device),)
    context = torch.randn(2, 8, 64, 32).to(device)
    gaflare = GAFLARE(dim=256, heads=8, dim_head=32, context_dim=32).to(device)
    outputs = gaflare(x, context)
    print(outputs[0].shape)
Let's remove this from the model code.
TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False)
if TE_AVAILABLE:
    import transformer_engine.pytorch as te
There is a newer, simpler syntax we can use here with OptionalImport.
Added the following import to gaflare.py and gale.py:
from physicsnemo.core.version_check import OptionalImport
te = OptionalImport("transformer_engine.pytorch", "0.1.0")
import transformer_engine.pytorch as te

class GAFLARE(nn.Module):
Can we revisit this name?

Yes, for sure! Please let me know if you have any suggestions.
update main
update geoflare
See Also
--------
:class:`GALE` : Original GeoTransolver GALE attention class.
:class:`GALE_block` : Transformer block using GAFLARE attention.
Incorrect cross-reference

GALE_block calls GAFLARE, and GALE is the alternative attention mechanism.
Should we change it to:
See Also
--------
:class:`GALE` : Original GeoTransolver GALE attention class.
:class:`GALE_block` : Transformer block **that calls GALE or GAFLARE** attention.
# Linear projections for cross-attention
self.cross_q = linear_layer(dim_head, dim_head)
self.cross_k = linear_layer(context_dim, dim_head)
self.cross_v = linear_layer(context_dim, dim_head)
If context_dim=0, these layers are silently created and will fail later (or produce garbage) if a non-empty context tensor is passed. A better design would skip creating the cross-attention layers entirely when context_dim=0. Or you can explicitly raise an error:
if context is not None and self.context_dim == 0:
raise ValueError(...)
Added if context_dim > 0: to skip creating the cross-attention layers.
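A plain-Python sketch of that guard, combining both suggestions (the class, attribute names, and the tuple placeholders standing in for nn.Linear are illustrative, not the real implementation):

```python
class CrossAttnConfig:
    """Sketch: only create cross-attention projections when context_dim > 0,
    and fail loudly if a context is passed without them."""

    def __init__(self, dim_head, context_dim=0):
        self.context_dim = context_dim
        if context_dim > 0:
            # placeholders standing in for nn.Linear(in_features, out_features)
            self.cross_q = ("Linear", dim_head, dim_head)
            self.cross_k = ("Linear", context_dim, dim_head)
            self.cross_v = ("Linear", context_dim, dim_head)

    def forward(self, x, context=None):
        if context is not None and self.context_dim == 0:
            raise ValueError(
                "context passed but context_dim=0; construct with context_dim > 0"
            )
        return x  # attention math elided in this sketch

cfg = CrossAttnConfig(dim_head=32, context_dim=0)
print(hasattr(cfg, "cross_k"))  # False
```

With context_dim=0 the cross-attention projections never exist, so a stray context tensor raises immediately instead of producing garbage downstream.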
use_te: bool = True,
context_dim: int = 0,
):
self.use_te = False  # te will disable FlashAttention for different size of q and k
I suggest we explicitly raise an error when use_te=True, instead of silently setting it to False.
Added:
if use_te:
    raise ValueError(
        "GAFLARE does not support Transformer Engine backend. "
        "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention."
    )
vpuri3 left a comment:
Hi @dakhare-creator, I've left several comments in the code. Implementing FLARE directly on a geometry problem can be error-prone. I'd recommend first validating FLARE on one of the paper's benchmark problems:
https://github.com/vpuri3/FLARE.py/blob/master/pdebench/models/flare.py
and then exposing the relevant layers to the geometry solvers interface.
        "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention."
    )
self.use_te = use_te
self.scale = 1.  # FLARE scale is 1.0
While 1 is the correct attention scale for dim_head = 8, it may not scale well to larger head dims. I'd recommend either exposing this setting to the user or setting it automatically, e.g.:
self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5)
Thanks for the recommendation. Added!
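The suggested rule can be written as a small helper (a sketch; the function name is illustrative and the dim_head <= 8 threshold follows the reviewer's comment above):

```python
def flare_scale(dim_head: int) -> float:
    # Scale 1.0 matches the FLARE reference code for small head dims;
    # fall back to the standard 1/sqrt(d) attention scaling for larger heads.
    return 1.0 if dim_head <= 8 else dim_head ** -0.5

print(flare_scale(8), flare_scale(64))  # 1.0 0.125
```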
dim : int
    Input dimension of the features.
heads : int, optional
    Number of attention heads. Default is 8.
dim_head : int, optional
    Dimension of each attention head. Default is 64.
Is it the case that dim = heads * dim_head? If so, maybe an assert or warning would be good in case the user passes in an inconsistent setting.
In this implementation, dim = heads * dim_head is not the case.
We explicitly define inner_dim = heads * dim_head and the dataflow looks like
dim -> inner_dim -> [heads, dim_head] -> inner_dim -> dim
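A quick sketch of that dataflow in terms of feature dimensions (pure Python, no torch; the stage names and the dim/heads/dim_head values are illustrative):

```python
# Trace the feature-dimension dataflow described above:
# dim -> inner_dim -> [heads, dim_head] -> inner_dim -> dim
dim, heads, dim_head = 256, 8, 64
inner_dim = heads * dim_head  # 512 here: inner_dim need not equal dim

stages = [
    ("input features", dim),
    ("qkv projection", inner_dim),
    ("split into heads", (heads, dim_head)),
    ("merge heads", inner_dim),
    ("output projection", dim),
]
for name, width in stages:
    print(name, width)
```

Since the output projection maps inner_dim back to dim, the layer composes cleanly even when heads * dim_head differs from the model width.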
# Use PyTorch's scaled dot-product attention
z = [F.scaled_dot_product_attention(_G, _k, _v, scale=1.0) for _G, _k, _v in zip(G, k, v)]
self_attention = [F.scaled_dot_product_attention(_k, _G, _z, scale=1.0) for _k, _G, _z in zip(k, G, z)]
self.self_k = linear_layer(dim_head, dim_head)
self.self_v = linear_layer(dim_head, dim_head)
I've found that using a more expressive projection here really helps performance on PDE problems.
The tradeoff here is described in Appendix F under heading "Tradeoff between query dynamics and key/value expressivity" in the paper: https://arxiv.org/pdf/2508.12594.
For PDE problems, I've found that replacing FFN type layers (C -> 4C -> GeLU -> C) with deeper but narrower MLPs can help because the mapping is often smoother / more “function-approximation-like,” and gains come from expressive feature transforms more than from content-addressable routing/memorization.
Here's the full model definition I used in the experiments in the paper:
https://github.com/vpuri3/FLARE.py/blob/master/pdebench/models/flare.py
I understand that deep KV projections would increase parameter counts. To compensate for that, we have validated that FLARE performs at par with other models at smaller hidden sizes (C=64 for FLARE outperforms C=128 for transolver).
Thank you for the suggestion. I’ll test this and include it in the next PR. I’ve also opened an issue to track the improvement: #1440
    ) for _cross_attention in cross_attention]
else:
    cross_attention = [F.scaled_dot_product_attention(_q, k, v, scale=1.0) for _q in q]
Cross attention with FLARE has not been fully fleshed out. My plan for cross attention is outlined on page 71 (5.1.2 Aim 1(b): conditioning mechanism for dynamic PDE surrogates) of this document:
https://drive.google.com/file/d/1SNDjQ0gMSZmv0jg49S-risEoDiwE63aY/view?usp=sharing
Thank you for the suggestion. I’ll test this and include it in the next PR. I’ve also opened an issue to track the improvement: #1440
/blossom-ci
Hi @dakhare-creator - with the geometry encoded FLARE, does this also enable standard FLARE attention and model?
Hi @coreyjadams, yes - if the context is None, the attention mechanism will be standard FLARE.
class FLARE(nn.Module):
    r"""FLARE: Fast Low-rank Attention Routing Engine attention layer.

    Adopted:
    - FLARE attention: Fast Low-rank Attention Routing Engine
      paper: https://arxiv.org/abs/2508.12594
This should probably go into experimental/nn/flare_attention.py?
self.Attn = FLARE(
    hidden_dim,
    spatial_shape=spatial_shape,
    heads=num_heads,
    dim_head=hidden_dim // num_heads,
    dropout=dropout,
    slice_num=slice_num,
    n_global_queries=slice_num,
    use_te=use_te,
    plus=plus,
Does FLARE support multiple dimensions? How does it handle irregular vs. regular data?
No, the FLARE implemented here is designed only for point clouds and does not work for regular data. Maybe we should consider raising an error for structured data input.
Correct - FLARE cannot make use of any connectivity information. If you feed a graph/mesh to FLARE, you have to feed only the unordered point cloud.
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.core.version_check import check_version_spec
from physicsnemo.experimental.models.transolver.flare import FLARE
transolver can't import from experimental.
Make flare.py an experimental model instead?
PhysicsNeMo Pull Request
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.