
GeoFLARE: added GALE_FA, an alternate attention to GALE, for GeoTransolver #1405

Open
dakhare-creator wants to merge 20 commits into NVIDIA:main from dakhare-creator:geoflare

Conversation

@dakhare-creator
Contributor

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@greptile-apps
Contributor

greptile-apps bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR adds GAFLARE (Geometry-Aware FLARE), an alternative attention mechanism to GALE for the GeoTransolver model.

Key Changes:

  • New gaflare.py module implementing FLARE-based attention with geometry-aware context
  • Modified gale.py to support dynamic attention type selection via attention_type parameter
  • Updated geotransolver.py to expose attention_type parameter (default: "GALE")

Critical Issues Found:

  • gaflare.py line 121: use_te parameter is hardcoded to False, ignoring the constructor argument
  • gaflare.py line 122: Attention scale hardcoded to 1.0 instead of proper dim_head**-0.5 scaling
  • gale.py line 393: Unsafe use of globals() for class selection - security risk and violates MOD-009 coding standard

Minor Issues:

  • Typo in docstring: "Origional" → "Original"
  • Missing newline at end of gaflare.py

Important Files Changed

  • physicsnemo/experimental/models/geotransolver/gaflare.py: New GAFLARE attention mechanism added. Critical issues: use_te parameter hardcoded to False (line 121), attention scale set to 1.0 instead of proper scaling. Minor: typo in docstring, missing newline at EOF.
  • physicsnemo/experimental/models/geotransolver/gale.py: Integration changes to support attention type selection. Critical security issue: unsafe use of globals() for dynamic class instantiation (line 393) allows potential code execution.
  • physicsnemo/experimental/models/geotransolver/geotransolver.py: Added attention_type parameter to allow switching between GALE and GAFLARE. Clean pass-through implementation with proper default value.

@greptile-apps greptile-apps bot left a comment


3 files reviewed, 6 comments


Comment on lines 121 to 122
self.use_te = False # te will disable FlashAttention for different size of q and k
self.scale = 1. #dim_head**-0.5
Contributor

use_te parameter ignored - hardcoded to False on line 121, making the use_te constructor parameter ineffective.

Suggested change
- self.use_te = False # te will disable FlashAttention for different size of q and k
- self.scale = 1. #dim_head**-0.5
+ super().__init__()
+ self.use_te = use_te and TE_AVAILABLE

Contributor Author

raise ValueError() added.

**kwargs,
):
self.use_te = False # te will disable FlashAttention for different size of q and k
self.scale = 1. #dim_head**-0.5
Contributor

self.scale set to 1.0 but commented code suggests it should be dim_head**-0.5 for proper attention scaling. Current implementation may affect attention quality.

Suggested change
- self.scale = 1. #dim_head**-0.5
+ self.scale = dim_head**-0.5

Contributor Author

Updated to: self.scale = 1.  # FLARE scale is 1.0
The FLARE code uses a scale of 1.0.

context = torch.randn(2, 8, 64, 32).to(device)
gaflare = GAFLARE(dim=256, heads=8, dim_head=32, context_dim=32).to(device)
outputs = gaflare(x, context)
print(outputs[0].shape) No newline at end of file
Contributor

Missing newline at end of file

Suggested change
print(outputs[0].shape)
print(outputs[0].shape)


Comment on lines 94 to 95
:class:`GALE` : Origional GeoTransolver GALE attention class.
:class:`GALE_block` : Transformer block using GAFLARE attention.
Contributor

Typo: "Origional" should be "Original"

Suggested change
- :class:`GALE` : Origional GeoTransolver GALE attention class.
- :class:`GALE_block` : Transformer block using GAFLARE attention.
+ :class:`GALE` : Original GeoTransolver GALE attention class.
+ :class:`GAFLARE_block` : Transformer block using GAFLARE attention.

Contributor Author

Origional -> Original

@greptile-apps
Contributor

greptile-apps bot commented Feb 11, 2026

Additional Comments (1)

physicsnemo/experimental/models/geotransolver/gale.py
Using globals() for dynamic class selection is unsafe - allows arbitrary code execution if attention_type comes from untrusted input. This violates MOD-009 (avoid string-based class selection). Consider explicit mapping:

ATTENTION_CLASSES = {
    "GALE": GALE,
    "GAFLARE": GAFLARE,
}
if attention_type not in ATTENTION_CLASSES:
    raise ValueError(f"Unknown attention_type: {attention_type}")
self.Attn = ATTENTION_CLASSES[attention_type](
    hidden_dim,
    heads=num_heads,
    # ... rest of arguments
)

@coreyjadams coreyjadams self-requested a review February 12, 2026 01:27
Collaborator

@coreyjadams coreyjadams left a comment


Hi @dakhare-creator - thanks for opening this. Overall it seems reasonable. Good to put it in experimental. I made a couple comments in the code, but also:

  • it'd be unusual for us to publish a "GeoFlare" model without also making available "Flare". Can you add standard FLARE attention to the experimental/nn folder so we can do that too?
  • We can also probably easily create a flare.py itself for the model.
  • How much overlap in the math and attention blocks is there between PhysicsAttention and FLARE Attention? PhysicsAttention can be used on 2d and 3d data, could this? I strongly suspect yes. Can we make this modular so we can reuse as much as possible between attention layers?

Once we're ready, let's also invite the FLARE authors to take a look.

# GALE attention layer
self.Attn = GALE(
if attention_type in globals():
self.Attn = globals()[attention_type](
Collaborator

Let's avoid using globals() like this; I think it's not needed. You can just do a match to get the attention type?

Contributor Author

Updated with match.
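For reference, a minimal sketch of the match-based selection (the constructor arguments here are illustrative placeholders, not the PR's final signature):

    # Sketch only: argument list is illustrative, not the PR's final code.
    match attention_type:
        case "GALE":
            self.Attn = GALE(hidden_dim, heads=num_heads)
        case "GAFLARE":
            self.Attn = GAFLARE(hidden_dim, heads=num_heads)
        case _:
            raise ValueError(f"Unknown attention_type: {attention_type}")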

Comment on lines 247 to 253
if __name__ == "__main__":
device = "cuda" if torch.cuda.is_available() else "cpu"
x = (torch.randn(2, 100, 256).to(device),)
context = torch.randn(2, 8, 64, 32).to(device)
gaflare = GAFLARE(dim=256, heads=8, dim_head=32, context_dim=32).to(device)
outputs = gaflare(x, context)
print(outputs[0].shape)
Collaborator

Let's remove this from the model code.

Contributor Author

Removed.

Comment on lines 35 to 37
TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False)
if TE_AVAILABLE:
import transformer_engine.pytorch as te
Collaborator

There is a newer, simpler syntax we can use here with OptionalImport.

Contributor Author

Added the following to gaflare.py and gale.py:
from physicsnemo.core.version_check import OptionalImport
te = OptionalImport("transformer_engine.pytorch", "0.1.0")

import transformer_engine.pytorch as te


class GAFLARE(nn.Module):
Collaborator

Can we revisit this name?

Contributor Author

Yes, for sure! Please let me know if you have any suggestions.

See Also
--------
:class:`GALE` : Original GeoTransolver GALE attention class.
:class:`GALE_block` : Transformer block using GAFLARE attention.
Collaborator

Incorrect cross-reference

Contributor Author

GALE_block calls GAFLARE, and GALE is the alternative attention mechanism. Should we change it as follows?

    See Also
    --------
    :class:`GALE` : Original GeoTransolver GALE attention class.
    :class:`GALE_block` : Transformer block **that calls GALE or GAFLARE** attention.

# Linear projections for cross-attention
self.cross_q = linear_layer(dim_head, dim_head)
self.cross_k = linear_layer(context_dim, dim_head)
self.cross_v = linear_layer(context_dim, dim_head)
Collaborator

@mnabian mnabian Feb 19, 2026

If context_dim=0, these layers are silently created, and will fail later (or produce garbage) if a non-empty context tensor is passed. A better design would skip creating the cross-attention layers entirely when context_dim=0. Or you can explicitly raise an error:

if context is not None and self.context_dim == 0:
      raise ValueError(...)

Contributor Author

Added if context_dim > 0: to skip creating the cross-attention layers.
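Roughly, the guarded construction looks like this sketch (layer names follow the quoted snippet; the forward() check is illustrative, not necessarily the merged code):

    # In __init__: only build cross-attention projections when a context is expected.
    if context_dim > 0:
        self.cross_q = linear_layer(dim_head, dim_head)
        self.cross_k = linear_layer(context_dim, dim_head)
        self.cross_v = linear_layer(context_dim, dim_head)
    else:
        self.cross_q = self.cross_k = self.cross_v = None

    # In forward(): fail loudly instead of producing garbage.
    if context is not None and self.context_dim == 0:
        raise ValueError("A context tensor was passed, but the module was built with context_dim=0.")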

use_te: bool = True,
context_dim: int = 0,
):
self.use_te = False # te will disable FlashAttention for different size of q and k
Collaborator

@mnabian mnabian Feb 19, 2026

I suggest we explicitly raise an error when use_te=True, instead of silently setting it to False.

Contributor Author

Added

if use_te:
    raise ValueError(
        "GAFLARE does not support Transformer Engine backend. "
        "Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention."
    )

@vpuri3 vpuri3 left a comment

Hi @dakhare-creator, I've left several comments in the code. Implementing FLARE directly on a geometry problem can be error-prone. I'd recommend first validating FLARE on one of the paper's benchmark problems

https://github.com/vpuri3/FLARE.py/blob/master/pdebench/models/flare.py

and then exposing the relevant layers to the geometry solvers interface.

"Use use_te=False; TE disables FlashAttention for differing q/k sizes in FLARE attention."
)
self.use_te = use_te
self.scale = 1. # FLARE scale is 1.0

While 1 is the correct attention scale for dim_head = 8, it may not scale well for greater head dims. I'd recommend either exposing this setting to the user or setting it automatically, like:

    self.scale = 1 if self.dim_head <= 8 else (self.dim_head ** -0.5)

Contributor Author

Thanks for the recommendation. Added!

Comment on lines +53 to +58
dim : int
Input dimension of the features.
heads : int, optional
Number of attention heads. Default is 8.
dim_head : int, optional
Dimension of each attention head. Default is 64.

Is it the case that dim = heads * dim_head? If so, maybe an assert or warning would be good in case the user passes in an inconsistent setting.

Contributor Author

In this implementation, dim = heads * dim_head is not the case.
We explicitly define inner_dim = heads * dim_head and the dataflow looks like
dim -> inner_dim -> [heads, dim_head] -> inner_dim -> dim
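A small, self-contained sketch of that flow (shapes and layer names are illustrative only, not the module's actual attribute names):

    import torch
    import torch.nn as nn

    dim, heads, dim_head = 256, 8, 32
    inner_dim = heads * dim_head            # need not equal dim

    to_inner = nn.Linear(dim, inner_dim)    # dim -> inner_dim
    to_out = nn.Linear(inner_dim, dim)      # inner_dim -> dim

    x = torch.randn(2, 100, dim)            # (batch, points, dim)
    h = to_inner(x)                         # (2, 100, inner_dim)
    h = h.view(2, 100, heads, dim_head)     # (2, 100, heads, dim_head)
    h = h.view(2, 100, inner_dim)           # back to (2, 100, inner_dim)
    y = to_out(h)                           # (2, 100, dim)
    print(y.shape)                          # torch.Size([2, 100, 256])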

Comment on lines 218 to 220
# Use PyTorch's scaled dot-product attention
z = [F.scaled_dot_product_attention(_G, _k, _v, scale=1.0) for _G, _k, _v in zip(G, k, v)]
self_attention = [F.scaled_dot_product_attention(_k, _G, _z, scale=1.0) for _k, _G, _z in zip(k, G, z)]

pass in self.scale.

Contributor Author

Done
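For clarity, the change amounts to threading self.scale through both calls; a sketch using the variable names from the quoted snippet, not necessarily the final diff:

    # Pass the module's configured scale instead of the hardcoded 1.0.
    z = [F.scaled_dot_product_attention(_G, _k, _v, scale=self.scale) for _G, _k, _v in zip(G, k, v)]
    self_attention = [F.scaled_dot_product_attention(_k, _G, _z, scale=self.scale) for _k, _G, _z in zip(k, G, z)]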

Comment on lines +137 to +138
self.self_k = linear_layer(dim_head, dim_head)
self.self_v = linear_layer(dim_head, dim_head)
@vpuri3 vpuri3 Feb 23, 2026

I've found that using a more expressive projection here really helps performance on PDE problems.

The tradeoff here is described in Appendix F under heading "Tradeoff between query dynamics and key/value expressivity" in the paper: https://arxiv.org/pdf/2508.12594.

For PDE problems, I've found that replacing FFN type layers (C -> 4C -> GeLU -> C) with deeper but narrower MLPs can help because the mapping is often smoother / more “function-approximation-like,” and gains come from expressive feature transforms more than from content-addressable routing/memorization.

Here's the full model definition I used in the experiments in the paper:

https://github.com/vpuri3/FLARE.py/blob/master/pdebench/models/flare.py

I understand that deep KV projections would increase parameter counts. To compensate for that, we have validated that FLARE performs on par with other models at smaller hidden sizes (C=64 for FLARE outperforms C=128 for Transolver).
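As a hedged illustration of that tradeoff, the single linear key projection could be swapped for a deeper but narrower MLP along these lines (depth and widths below are placeholders, not the paper's exact configuration):

    import torch.nn as nn

    dim_head = 32

    # Current PR: single linear key projection.
    self_k_linear = nn.Linear(dim_head, dim_head)

    # Illustrative alternative: deeper but narrower MLP on the key/value path.
    self_k_mlp = nn.Sequential(
        nn.Linear(dim_head, dim_head),
        nn.GELU(),
        nn.Linear(dim_head, dim_head),
        nn.GELU(),
        nn.Linear(dim_head, dim_head),
    )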

Contributor Author

Thank you for the suggestion. I’ll test this and include it in the next PR. I’ve also opened an issue to track the improvement: #1440

) for _cross_attention in cross_attention]
else:
cross_attention = [F.scaled_dot_product_attention(_q, k, v, scale=1.0) for _q in q]


Cross attention with FLARE has not been fully fleshed out. My plan for cross attention is outlined on page 71 (5.1.2 Aim 1(b): conditioning mechanism for dynamic PDE surrogates) of this document:

https://drive.google.com/file/d/1SNDjQ0gMSZmv0jg49S-risEoDiwE63aY/view?usp=sharing

Contributor Author

Thank you for the suggestion. I’ll test this and include it in the next PR. I’ve also opened an issue to track the improvement: #1440

@mnabian
Collaborator

mnabian commented Feb 25, 2026

/blossom-ci

@coreyjadams
Collaborator

Hi @dakhare-creator - with the geometry encoded FLARE, does this also enable standard FLARE attention and model?

@dakhare-creator
Contributor Author

Hi @dakhare-creator - with the geometry encoded FLARE, does this also enable standard FLARE attention and model?

Hi @coreyjadams, yes. If the context is None, the attention mechanism will be standard FLARE.
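As a rough usage sketch (argument names follow the example earlier in this thread; the exact call signature may differ in the final code):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(2, 100, 256).to(device)

    # Assumed: with context_dim=0 and no context tensor, GAFLARE reduces to standard FLARE attention.
    gaflare = GAFLARE(dim=256, heads=8, dim_head=32, context_dim=0).to(device)
    out = gaflare(x, None)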

@dakhare-creator dakhare-creator changed the title from "GeoFLARE: added GAFLARE, an alternate attention to GALE, for GeoTransolver" to "GeoFLARE: added GALE_FA, an alternate attention to GALE, for GeoTransolver" on Feb 25, 2026
Comment on lines +38 to +43
class FLARE(nn.Module):
r"""FLARE: Fast Low-rank Attention Routing Engine attention layer.
Adopted:
- FLARE attention: Fast Low-rank Attention Routing Engine
paper: https://arxiv.org/abs/2508.12594

Collaborator

This should probably go into experimental/nn/flare_attention.py?

Comment on lines +260 to -244
self.Attn = FLARE(
hidden_dim,
spatial_shape=spatial_shape,
heads=num_heads,
dim_head=hidden_dim // num_heads,
dropout=dropout,
slice_num=slice_num,
n_global_queries=slice_num,
use_te=use_te,
plus=plus,
Collaborator

Does FLARE support multiple dimensions? How does it handle irregular vs. regular data?

Contributor Author

No, the FLARE implemented here is designed only for point clouds and does not work for regular data. Maybe we should consider raising an error for structured data input.
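If we do add such a check, a hypothetical guard could look like this (the assumed point-cloud shape of (batch, num_points, channels) is illustrative):

    # Hypothetical guard: FLARE expects an unordered point cloud of shape
    # (batch, num_points, channels); reject gridded / structured inputs.
    if x.ndim != 3:
        raise ValueError(
            f"FLARE expects point-cloud input of shape (batch, num_points, channels); got {tuple(x.shape)}."
        )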


Correct - FLARE cannot make use of any connectivity information. If you feed a graph/mesh to FLARE, you have to feed only the unordered point cloud.

from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.core.version_check import check_version_spec
from physicsnemo.experimental.models.transolver.flare import FLARE
Collaborator

transolver can't import from experimental.

Make flare.py an experimental model instead?
