Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docsrc/user_guide/compilation/dynamic_shapes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,49 @@ Here's a simple example that exports a matmul layer with some restrictions on dy
# Run inference
trt_gm(*inputs)

Sharing a Dynamic Dimension Across Multiple Inputs
---------------------------------------------------

HuggingFace-style encoders and similar models take multiple inputs (e.g.
``input_ids`` and ``attention_mask``) whose dynamic axes **must be equal at
runtime**. If you assign an independent dynamic dimension to each input,
``torch.export`` detects that the two independent symbols are forced equal by
the model's forward pass (e.g. a broadcast) and raises a
``ConstraintViolationError``.

``torch_tensorrt.Input(shared_dims={axis: name})`` solves this without any
manual ``torch.export`` work. Axes that share the same name across inputs are
exported as a single ``torch.export.Dim``, so the equality constraint is
satisfied automatically.

.. code-block:: python

import torch
import torch_tensorrt

model = MyHFEncoder().eval().cuda()

inputs = [
torch_tensorrt.Input(
min_shape=(1, 16), opt_shape=(4, 16), max_shape=(8, 16),
dtype=torch.int64, name="input_ids",
shared_dims={0: "B"}, # axis 0 named "B"
),
torch_tensorrt.Input(
min_shape=(1, 16), opt_shape=(4, 16), max_shape=(8, 16),
dtype=torch.int64, name="attention_mask",
shared_dims={0: "B"}, # same name → same Dim
),
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

The same name on the same axis index across different inputs produces one
shared ``Dim``; different names produce independent ``Dim``\s. Multiple axes
can be shared simultaneously with ``shared_dims={0: "B", 1: "S"}``.

See the full runnable example: :ref:`shared_dynamic_dims`.

Dynamic shapes using torch.compile (JIT)
----------------------------------------

Expand Down
1 change: 1 addition & 0 deletions docsrc/user_guide/compilation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ How Torch-TensorRT compiles models: the JIT ``torch.compile`` path, the AOT
compilation_settings
dynamic_shapes
Example: Compiling Models with Dynamic Input Shapes <../../tutorials/_rendered_examples/dynamo/compile_with_dynamic_inputs>
Example: Sharing Dynamic Dimensions Across Inputs <../../tutorials/_rendered_examples/dynamo/shared_dynamic_dims_example>
unsupported_ops
181 changes: 181 additions & 0 deletions examples/dynamo/shared_dynamic_dims_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""
.. _shared_dynamic_dims:

Sharing Dynamic Dimensions Across Inputs
==========================================================

When a model takes multiple inputs whose dynamic axes must be **equal at
runtime** — for example, HuggingFace-style encoders where ``input_ids`` and
``attention_mask`` are both shaped ``[batch, seq_len]`` — naively assigning an
independent dynamic dimension to each input causes ``torch.export`` to raise a
``ConstraintViolationError``. The exporter detects that the two independent
symbols are forced equal by the model's forward pass (e.g. a broadcast) and
rejects the export.

``torch_tensorrt.Input(shared_dims={axis: name})`` solves this: axes that share
the same name across inputs are exported as a single ``torch.export.Dim``, so
the equality constraint is satisfied automatically. All dynamic-shape intent
lives on the ``Input`` objects — no separate ``dynamic_shapes`` argument or
``torch.export`` knowledge is required at the call site.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torch.nn as nn
import torch_tensorrt
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

# %%
# Define a HuggingFace-style encoder whose two inputs share the batch axis.
# The ``embed * mask`` broadcast forces ``input_ids.shape[0] ==
# attention_mask.shape[0]`` at every forward call — exactly the pattern that
# triggers ``ConstraintViolationError`` when the batch axis is exported as two
# independent ``Dim`` objects.


class SharedDimEncoder(nn.Module):
    """Toy encoder: embed token ids, scale by the attention mask, then project.

    The embedded tokens are multiplied element-wise by the mask (after the mask
    gains a trailing unit axis), so ``input_ids`` and ``attention_mask`` must
    agree on their leading axes.
    """

    def __init__(self, vocab: int = 1024, hidden: int = 64):
        super().__init__()
        # Token-id -> vector lookup, followed by a square linear projection.
        self.embed = nn.Embedding(vocab, hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        token_vectors = self.embed(input_ids)  # [B, S, hidden]
        # The trailing unit axis lets the mask broadcast across ``hidden``;
        # the cast makes the integer mask match the embedding dtype.
        keep = attention_mask.unsqueeze(-1).to(token_vectors.dtype)  # [B, S, 1]
        masked_vectors = token_vectors * keep
        return self.proj(masked_vectors)  # [B, S, hidden]


# Build the encoder, move it to the CUDA device, and switch it to evaluation mode.
model = SharedDimEncoder()
model = model.cuda()
model = model.eval()

# %%
# Without ``shared_dims`` — raises ``ConstraintViolationError``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Using independent ``Input`` objects like this would fail at export time:
#
# .. code-block:: python
#
# inputs = [
# torch_tensorrt.Input(min_shape=(1,16), opt_shape=(4,16), max_shape=(8,16), dtype=torch.int64),
# torch_tensorrt.Input(min_shape=(1,16), opt_shape=(4,16), max_shape=(8,16), dtype=torch.int64),
# ]
# # torch.export mints independent symbols s0, s1 for the batch axis of
# # each input. The broadcast forces Eq(s0, s1), which the exporter rejects.
#
# %%
# With ``shared_dims`` — correct approach (positional inputs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Annotate the batch axis (axis 0) with the same name ``"B"`` on both inputs.
# Torch-TensorRT creates a single shared ``torch.export.Dim("B")`` for that
# axis so the equality constraint is satisfied up front.

# Each spec describes a [batch, 16] int64 tensor whose batch axis may range
# from 1 to 8 (opt_shape uses 4). Axis 0 carries the same name "B" on both.
input_ids_spec = torch_tensorrt.Input(
    name="input_ids",
    dtype=torch.int64,
    min_shape=(1, 16),
    opt_shape=(4, 16),
    max_shape=(8, 16),
    shared_dims={0: "B"},
)
attention_mask_spec = torch_tensorrt.Input(
    name="attention_mask",
    dtype=torch.int64,
    min_shape=(1, 16),
    opt_shape=(4, 16),
    max_shape=(8, 16),
    shared_dims={0: "B"},
)
# Order matches the positional parameters of ``forward``.
inputs = [input_ids_spec, attention_mask_spec]

positional_compile_options = {
    "min_block_size": 1,
    "cache_built_engines": False,
    "reuse_cached_engines": False,
}
trt_model = torch_tensorrt.compile(
    model, ir="dynamo", inputs=inputs, **positional_compile_options
)

# %%
# Verify correctness at multiple batch sizes within the declared range.

for batch_size in (4, 2, 1):
    sample_shape = (batch_size, 16)
    token_ids = torch.randint(
        0, 1024, sample_shape, dtype=torch.int64, device="cuda"
    )
    all_ones_mask = torch.ones(sample_shape, dtype=torch.int64, device="cuda")

    # Run the eager model and the compiled model on the very same tensors.
    with torch.no_grad():
        expected = model(token_ids, all_ones_mask)
        actual = trt_model(token_ids, all_ones_mask)

    cos_sim = cosine_similarity(expected, actual)
    assert cos_sim > COSINE_THRESHOLD, f"Numerical mismatch at batch_size={batch_size}: cos_sim={cos_sim:.4f}"
    print(f"batch_size={batch_size} cos_sim={cos_sim:.6f} ✓")

# %%
# With ``shared_dims`` — kwarg inputs
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The same feature works with ``kwarg_inputs``, which is the natural form for
# HuggingFace models whose ``forward`` signature uses keyword arguments.

# Keys are the keyword names of ``forward``; each value is the same spec used
# in the positional example, with axis 0 named "B" on both entries.
kwarg_inputs = dict(
    input_ids=torch_tensorrt.Input(
        name="input_ids",
        dtype=torch.int64,
        min_shape=(1, 16),
        opt_shape=(4, 16),
        max_shape=(8, 16),
        shared_dims={0: "B"},
    ),
    attention_mask=torch_tensorrt.Input(
        name="attention_mask",
        dtype=torch.int64,
        min_shape=(1, 16),
        opt_shape=(4, 16),
        max_shape=(8, 16),
        shared_dims={0: "B"},
    ),
)

kwarg_compile_options = {
    "min_block_size": 1,
    "cache_built_engines": False,
    "reuse_cached_engines": False,
}
trt_model_kwargs = torch_tensorrt.compile(
    model, ir="dynamo", kwarg_inputs=kwarg_inputs, **kwarg_compile_options
)

# One batch of four sequences, passed by keyword to both models.
sample_kwargs = {
    "input_ids": torch.randint(
        0, 1024, (4, 16), dtype=torch.int64, device="cuda"
    ),
    "attention_mask": torch.ones((4, 16), dtype=torch.int64, device="cuda"),
}

with torch.no_grad():
    ref = model(**sample_kwargs)
    out = trt_model_kwargs(**sample_kwargs)

cos_sim = cosine_similarity(ref, out)
assert cos_sim > COSINE_THRESHOLD, f"kwarg path mismatch: cos_sim={cos_sim:.4f}"
print(f"kwarg_inputs path cos_sim={cos_sim:.6f} ✓")

# %%
# Sharing multiple axes
# ^^^^^^^^^^^^^^^^^^^^^^
#
# If both batch and sequence length are dynamic and must be shared, annotate
# both axes on each input:
#
# .. code-block:: python
#
# shared_dims={0: "B", 1: "S"}
#
# The same name on the same axis across different inputs produces one shared
# ``Dim``; different names on different axes produce independent ``Dim``\s.

# Reached only when none of the assertions above raised.
print("\nAll checks passed.")
82 changes: 82 additions & 0 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class _ShapeMode(Enum):
torch_tensor: torch.Tensor = None
name: str = ""
is_shape_tensor: bool = False
# NOTE(review): this ``{}`` is one class-level dict object. In the constructor
# code visible in this change, ``self.shared_dims`` is only assigned on the
# dynamic-shape path, so other instances would presumably fall back to this
# shared object; an in-place mutation through one instance would then be seen
# by all of them. Confirm nothing mutates ``shared_dims`` in place.
shared_dims: Dict[int, str] = (
{}
) #: Optional {axis_index: name} for dynamic axes. The same name across inputs is exported as one shared ``torch.export.Dim`` (e.g. a batch axis shared by ``input_ids`` and ``attention_mask``).

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""__init__ Method for torch_tensorrt.Input
Expand Down Expand Up @@ -162,11 +165,57 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
"if you try to run inference with empty tensor inputs."
)

# Namedtuple shape API: field names encode per-axis dimension names.
# Convert to shared_dims so _tracer.py needs no changes — axes with the
# same name across inputs become one shared torch.export.Dim.
if hasattr(kwargs["min_shape"], "_fields"):
fields = kwargs["min_shape"]._fields
if not (
hasattr(kwargs["opt_shape"], "_fields")
and hasattr(kwargs["max_shape"], "_fields")
):
raise TypeError(
"If min_shape is a namedtuple, opt_shape and max_shape must also be namedtuples"
)
if not (
kwargs["opt_shape"]._fields == fields
and kwargs["max_shape"]._fields == fields
):
raise ValueError(
"min_shape, opt_shape, max_shape namedtuples must have identical field names"
)
# Only tag dynamic axes (min != max); static axes need no shared Dim.
self.shared_dims = {
i: name
for i, name in enumerate(fields)
if self.shape["min_shape"][i] != self.shape["max_shape"][i]
}

if "shared_dims" in kwargs and kwargs["shared_dims"]:
if hasattr(kwargs["min_shape"], "_fields"):
# Not allowed:
# S = namedtuple('S', ['b', 'c'])
# Input(min_shape=S(1,2), opt_shape=S(4,2), max_shape=S(8,2),
# shared_dims={0: "b"}) # ← redundant and ambiguous
raise ValueError(
"Cannot specify both a namedtuple min_shape and shared_dims; "
"use one or the other to name dynamic axes."
)
self.shared_dims = Input._parse_shared_dims(
kwargs["shared_dims"], self.shape
)

else:
raise ValueError(
f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments"
)

if kwargs.get("shared_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC:
raise ValueError(
"shared_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); "
"it has no meaning for a statically shaped Input."
)

if "dtype" in kwargs:
self.dtype = dtype._from(kwargs["dtype"])

Expand Down Expand Up @@ -261,6 +310,39 @@ def equivalent_spec(a: Input, b: Input) -> bool:
]
return all(checks)

@staticmethod
def _parse_shared_dims(
    shared_dims: Any, shape: Dict[str, Tuple[int, ...]]
) -> Dict[int, str]:
    """Validate and normalize the ``shared_dims`` mapping ({axis: name}).

    Each named axis must be a valid index into the shape and must be
    genuinely dynamic (``min != max``); a static axis cannot vary, so naming
    it for cross-input sharing is a user error.

    Args:
        shared_dims: Candidate ``{axis_index: name}`` mapping.
        shape: Mapping that provides ``"min_shape"`` and ``"max_shape"``
            sequences; the rank is taken from ``"min_shape"``.

    Returns:
        A new dict holding the validated ``{axis_index: name}`` entries in
        the iteration order of ``shared_dims``.

    Raises:
        TypeError: If ``shared_dims`` is not a ``dict``.
        ValueError: If a key is not an ``int`` in ``[0, rank)`` (``bool``
            keys are rejected), if a name is not a non-empty ``str``, or if
            the named axis has equal min and max extents.
    """
    if not isinstance(shared_dims, dict):
        raise TypeError(
            f"shared_dims must be a dict of {{axis_index: name}}, got {type(shared_dims)}"
        )

    min_shape = shape["min_shape"]
    max_shape = shape["max_shape"]
    rank = len(min_shape)

    parsed: Dict[int, str] = {}
    for axis, dim_name in shared_dims.items():
        # ``bool`` is a subclass of ``int``: without the explicit exclusion
        # ``{True: "S"}`` would silently be treated as axis 1.
        is_axis_index = isinstance(axis, int) and not isinstance(axis, bool)
        if not is_axis_index or not (0 <= axis < rank):
            raise ValueError(
                f"shared_dims key {axis!r} is not a valid axis index for an input of rank {rank}"
            )
        if not isinstance(dim_name, str) or not dim_name:
            raise ValueError(
                f"shared_dims value for axis {axis} must be a non-empty string, got {dim_name!r}"
            )
        if min_shape[axis] == max_shape[axis]:
            raise ValueError(
                f"Axis {axis} named '{dim_name}' is static "
                f"(min == max == {min_shape[axis]}); only dynamic axes can be named."
            )
        parsed[axis] = dim_name
    return parsed

@staticmethod
def _supported_input_size_type(input_size: Any) -> bool:
if isinstance(input_size, torch.Size):
Expand Down
Loading
Loading