Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed memory explosion during GeoTransolver inference on large meshes (2M+ cells).
  The `broadcast_global_features` option is now disabled in the datapipe for VTK inference;
  broadcasting is instead performed per sub-batch in `batched_inference_loop`, maintaining
  compatibility with models trained with `broadcast_global_features: true`.

### Security

### Dependencies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,6 @@ def create_datapipe(
optional_keys = [
"include_normals",
"include_sdf",
"broadcast_global_features",
"include_geometry",
"geometry_sampling",
"translational_invariance",
Expand All @@ -465,6 +464,12 @@ def create_datapipe(
if cfg.data.get(key, None) is not None:
overrides[key] = cfg.data[key]

# IMPORTANT: Always disable broadcast_global_features in the datapipe for inference
# on large meshes. The broadcasting will be done per sub-batch in batched_inference_loop
# to avoid memory explosion on huge meshes.
# This works regardless of how the model was trained (broadcast true or false).
overrides["broadcast_global_features"] = False

# Create the datapipe with no resolution limit (we handle batching ourselves)
datapipe = TransolverDataPipe(
input_path=None, # We're not using the dataset iterator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,25 @@ def batched_inference_loop(
local_embeddings = batch["embeddings"][:, index_block]
local_fields = batch["fields"][:, index_block]

# fx does not need to be sliced for TransolverX:
# Handle fx (global features) based on model type:
if "geometry" not in batch.keys():
# Transolver path - fx is broadcast to all points, slice for sub-batch
local_fx = batch["fx"][:, index_block]
else:
local_fx = batch["fx"]
# GeoTransolver path - broadcast fx to sub-batch size on-demand
# (avoids memory explosion on large meshes by not pre-broadcasting to full mesh)
sub_batch_size = index_block.shape[0]
fx = batch["fx"]

# Normalize to 3D (B, tokens, features) - datapipe may add extra dims
while fx.ndim > 3:
fx = fx.squeeze(-2) # squeeze token dim from right

# Broadcast single-token fx, or slice full-mesh fx
if fx.shape[1] == 1:
local_fx = fx.expand(-1, sub_batch_size, -1)
else:
local_fx = fx[:, index_block]

local_batch = {
"fx": local_fx,
Expand Down