From 0fe32d64cb8a126a1cd76e3073af70ef481dfdb1 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 4 May 2026 21:50:58 -0700 Subject: [PATCH 1/6] dynamic shape arg --- py/torch_tensorrt/_compile.py | 36 ++- py/torch_tensorrt/dynamo/_tracer.py | 11 +- .../dynamo/models/test_shared_dynamic_dim.py | 250 ++++++++++++++++++ 3 files changed, 289 insertions(+), 8 deletions(-) create mode 100644 tests/py/dynamo/models/test_shared_dynamic_dim.py diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 65986a276c..10f3542c7f 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -7,7 +7,18 @@ import platform import warnings from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) import torch from torch_tensorrt._enums import dtype @@ -191,6 +202,7 @@ def compile( arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[Dict[str, Any]] = None, enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, + dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> ( torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any] @@ -226,6 +238,14 @@ def compile( kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path) + dynamic_shapes (Any): Optional ``dynamic_shapes`` dict (or list / nested + structure) forwarded to ``torch.export.export``. Supply this to share a + ``Dim`` across multiple inputs (e.g. when ``input_ids`` and ``attention_mask`` + must have the same batch size at runtime). 
When omitted, dynamic shapes are + auto-inferred from per-input ``min_shape``/``max_shape`` and **each input gets + its own independent symbol** -- which fails ``torch.export``'s constraint + check for models that broadcast across these axes. Only consulted when + ``module`` is an ``nn.Module`` (ignored for ``ExportedProgram``). **kwargs: Additional settings for the specific requested strategy (See submodules for more info) Returns: @@ -296,7 +316,7 @@ def _fx_input_interface( return compiled_fx_module elif target_ir == _IRType.dynamo: # Prepare torch and torchtrt inputs - if arg_inputs is None and inputs is None: + if arg_inputs is None and inputs is None and not kwarg_inputs: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") elif arg_inputs is not None and inputs is not None: @@ -311,8 +331,10 @@ def _fx_input_interface( from torch_tensorrt.dynamo.utils import prepare_inputs - if not isinstance(arg_inputs, collections.abc.Sequence): - arg_inputs = [arg_inputs] # type: ignore + if arg_inputs is None: + arg_inputs = [] + elif not isinstance(arg_inputs, collections.abc.Sequence): + arg_inputs = [arg_inputs] torchtrt_arg_inputs = prepare_inputs(arg_inputs) torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs) @@ -324,6 +346,7 @@ def _fx_input_interface( module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, + dynamic_shapes=dynamic_shapes, **kwargs, ) trt_graph_module = dynamo_compile( @@ -830,8 +853,13 @@ def _all_are_input_objects(obj: Any) -> bool: f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}" ) +<<<<<<< HEAD arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore[arg-type] kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore[assignment] +======= + arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) + kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) +>>>>>>> 7fa5d838 (dynamic shape 
arg) else: # Mixed case: some inputs are Tensors, some are Input objects diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 0595c6a8f9..e36c3b5240 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -19,6 +19,7 @@ def trace( *, arg_inputs: Optional[Tuple[Any, ...]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, + dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -65,7 +66,7 @@ def trace( raise AssertionError( "'arg_inputs' and 'inputs' should not be used at the same time." ) - arg_inputs = inputs or arg_inputs + arg_inputs = inputs if inputs is not None else arg_inputs if kwarg_inputs is None: kwarg_inputs = {} @@ -73,9 +74,11 @@ def trace( device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - # Constructing dynamic shape list as a nested dict - dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs) - dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + if dynamic_shapes is None: + # Auto-inferred dims are independent per input; pass dynamic_shapes + # explicitly to share a Dim across inputs. + dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs) + dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) exp_program = export( mod, tuple(torch_arg_inputs), diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py new file mode 100644 index 0000000000..1f281e6851 --- /dev/null +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -0,0 +1,250 @@ +# type: ignore +""" +Tests for the ``dynamic_shapes=`` passthrough kwarg on ``torch_tensorrt.compile``. 
+ +Background: when a model takes multiple inputs whose dynamic axes must be +**equal at runtime** (e.g. HF encoders with ``input_ids`` / ``attention_mask`` +both shaped ``[B, S]``), the legacy auto-inference path in +``dynamo/_tracer.py`` mints an *independent* ``Dim`` per input. ``torch.export`` +then fails its constraint check for any forward() that broadcasts across those +axes (here: ``embed(input_ids) * mask.unsqueeze(-1)``), raising +``ConstraintViolationError``. + +These tests exercise the new ``dynamic_shapes=`` passthrough that lets the +caller supply a shared ``Dim`` directly to ``torch_tensorrt.compile`` -- +mirroring the ``torch.export.export(dynamic_shapes=...)`` signature -- so the +shared-batch case compiles end to end without the caller having to pre-export +the module themselves. +""" +import unittest + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch.export import Dim +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +class _SharedBatchEncoder(nn.Module): + """HF-style encoder stand-in: two int64 inputs sharing the batch axis. + + The ``embed(input_ids) * mask.unsqueeze(-1)`` broadcast forces + ``input_ids.size(0) == attention_mask.size(0)`` -- the relationship the + auto-inference path cannot express. 
+ """ + + def __init__(self, vocab: int = 1024, hidden: int = 32): + super().__init__() + self.embed = nn.Embedding(vocab, hidden) + self.proj = nn.Linear(hidden, hidden) + + def forward(self, input_ids, attention_mask): + x = self.embed(input_ids) + mask = attention_mask.unsqueeze(-1).to(x.dtype) + return self.proj(x * mask) + + +def _kwarg_inputs(seq: int = 16, batch_min: int = 1, batch_max: int = 4): + return { + "input_ids": torchtrt.Input( + min_shape=(batch_min, seq), + opt_shape=(batch_max, seq), + max_shape=(batch_max, seq), + dtype=torch.int64, + name="input_ids", + ), + "attention_mask": torchtrt.Input( + min_shape=(batch_min, seq), + opt_shape=(batch_max, seq), + max_shape=(batch_max, seq), + dtype=torch.int64, + name="attention_mask", + ), + } + + +@pytest.mark.unit +@pytest.mark.critical +def test_dynamic_shapes_passthrough_with_shared_batch_dim(): + """With ``dynamic_shapes={..: {0: batch}, ..: {0: batch}}`` (one shared + ``Dim``), compile succeeds and the engine matches the eager model.""" + model = _SharedBatchEncoder().eval().cuda() + + batch = Dim("batch", min=1, max=4) + dynamic_shapes = { + "input_ids": {0: batch}, + "attention_mask": {0: batch}, + } + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + kwarg_inputs=_kwarg_inputs(), + dynamic_shapes=dynamic_shapes, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + # Sample at the optimization shape and at a smaller batch within the range. 
+ for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(input_ids=ids, attention_mask=mask) + out = trt_mod(input_ids=ids, attention_mask=mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Shared-batch encoder out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_dynamic_shapes_passthrough_positional_tuple_form(): + """``torch.export`` also accepts ``dynamic_shapes`` as a tuple matching the + positional-args order. Verify the passthrough handles that form too.""" + model = _SharedBatchEncoder().eval().cuda() + + batch = Dim("batch", min=1, max=4) + seq = 16 + positional_inputs = [ + torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="input_ids", + ), + torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="attention_mask", + ), + ] + # Tuple form: one entry per positional arg, in declaration order. + dynamic_shapes = ({0: batch}, {0: batch}) + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=positional_inputs, + dynamic_shapes=dynamic_shapes, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(ids, mask) + out = trt_mod(ids, mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Tuple-form dynamic_shapes out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_dynamic_shapes_passthrough_mixed_args_and_kwargs(): + """One positional input, one kwarg input, sharing a batch ``Dim``. 
Uses the + unified dict-by-parameter-name form, which spans both positional and keyword + parameters.""" + model = _SharedBatchEncoder().eval().cuda() + + batch = Dim("batch", min=1, max=4) + seq = 16 + + # input_ids passed positionally, attention_mask as a kwarg. + positional_inputs = [ + torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="input_ids", + ), + ] + kwarg_inputs = { + "attention_mask": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="attention_mask", + ), + } + dynamic_shapes = { + "input_ids": {0: batch}, + "attention_mask": {0: batch}, + } + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=positional_inputs, + kwarg_inputs=kwarg_inputs, + dynamic_shapes=dynamic_shapes, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(ids, attention_mask=mask) + out = trt_mod(ids, attention_mask=mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Mixed args/kwargs out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_dynamic_shapes_default_path_unchanged_for_static_inputs(): + """Sanity check: when ``dynamic_shapes=None`` and inputs are fully static, + behavior is unchanged from the legacy path.""" + + class StaticModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8) + + def forward(self, x): + return self.linear(x) + + model = StaticModel().eval().cuda() + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=[torchtrt.Input(shape=(2, 8), dtype=torch.float32, name="x")], + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + x = 
torch.randn((2, 8), device="cuda") + with torch.no_grad(): + ref = model(x) + out = trt_mod(x) + assertions.assertTrue(cosine_similarity(ref, out) > COSINE_THRESHOLD) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 2811f516d03772c1818587f6469791729463d24d Mon Sep 17 00:00:00 2001 From: apbose Date: Wed, 3 Jun 2026 23:37:22 -0700 Subject: [PATCH 2/6] shared dynamic dims across inputs via Inputs --- py/torch_tensorrt/_Input.py | 47 +++++ py/torch_tensorrt/_compile.py | 32 +-- py/torch_tensorrt/dynamo/_tracer.py | 86 ++++++-- .../dynamo/models/test_shared_dynamic_dim.py | 197 ++++++++---------- 4 files changed, 221 insertions(+), 141 deletions(-) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index b547afb278..27ee329f69 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -51,6 +51,9 @@ class _ShapeMode(Enum): torch_tensor: torch.Tensor = None name: str = "" is_shape_tensor: bool = False + name_dims: Dict[int, str] = ( + {} + ) #: Optional {axis_index: name} for dynamic axes. The same name across inputs is exported as one shared ``torch.export.Dim`` (e.g. a batch axis shared by ``input_ids`` and ``attention_mask``). def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input @@ -162,11 +165,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "if you try to run inference with empty tensor inputs." ) + if "name_dims" in kwargs and kwargs["name_dims"]: + self.name_dims = Input._parse_name_dims( + kwargs["name_dims"], self.shape + ) + else: raise ValueError( f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) + if kwargs.get("name_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC: + raise ValueError( + "name_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); " + "it has no meaning for a statically shaped Input." 
+ ) + if "dtype" in kwargs: self.dtype = dtype._from(kwargs["dtype"]) @@ -261,6 +275,39 @@ def equivalent_spec(a: Input, b: Input) -> bool: ] return all(checks) + @staticmethod + def _parse_name_dims( + name_dims: Any, shape: Dict[str, Tuple[int, ...]] + ) -> Dict[int, str]: + """Validate and normalize the ``name_dims`` mapping ({axis: name}). + + Each named axis must be a valid index into the shape and must be + genuinely dynamic (``min != max``); a static axis cannot vary, so naming + it for cross-input sharing is a user error. + """ + if not isinstance(name_dims, dict): + raise TypeError( + f"name_dims must be a dict of {{axis_index: name}}, got {type(name_dims)}" + ) + rank = len(shape["min_shape"]) + parsed: Dict[int, str] = {} + for axis, dim_name in name_dims.items(): + if not isinstance(axis, int) or not (0 <= axis < rank): + raise ValueError( + f"name_dims key {axis!r} is not a valid axis index for an input of rank {rank}" + ) + if not isinstance(dim_name, str) or not dim_name: + raise ValueError( + f"name_dims value for axis {axis} must be a non-empty string, got {dim_name!r}" + ) + if shape["min_shape"][axis] == shape["max_shape"][axis]: + raise ValueError( + f"Axis {axis} named '{dim_name}' is static " + f"(min == max == {shape['min_shape'][axis]}); only dynamic axes can be named." 
+ ) + parsed[axis] = dim_name + return parsed + @staticmethod def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 10f3542c7f..74f91b5bd0 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -17,7 +17,6 @@ Set, Tuple, Union, - cast, ) import torch @@ -63,6 +62,7 @@ ) from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._tracer import ( + build_dim_registry, get_dynamic_shapes_args, get_dynamic_shapes_kwargs, ) @@ -202,7 +202,6 @@ def compile( arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[Dict[str, Any]] = None, enabled_precisions: Optional[Set[Union[torch.dtype, dtype]]] = None, - dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> ( torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any] @@ -238,14 +237,6 @@ def compile( kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path) - dynamic_shapes (Any): Optional ``dynamic_shapes`` dict (or list / nested - structure) forwarded to ``torch.export.export``. Supply this to share a - ``Dim`` across multiple inputs (e.g. when ``input_ids`` and ``attention_mask`` - must have the same batch size at runtime). When omitted, dynamic shapes are - auto-inferred from per-input ``min_shape``/``max_shape`` and **each input gets - its own independent symbol** -- which fails ``torch.export``'s constraint - check for models that broadcast across these axes. Only consulted when - ``module`` is an ``nn.Module`` (ignored for ``ExportedProgram``). 
**kwargs: Additional settings for the specific requested strategy (See submodules for more info) Returns: @@ -324,7 +315,7 @@ def _fx_input_interface( "'arg_inputs' and 'inputs' should not be used at the same time." ) if inputs is not None: - arg_inputs = inputs + arg_inputs = inputs # type: ignore[assignment] if kwarg_inputs is None: kwarg_inputs = {} @@ -346,7 +337,6 @@ def _fx_input_interface( module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, - dynamic_shapes=dynamic_shapes, **kwargs, ) trt_graph_module = dynamo_compile( @@ -424,7 +414,7 @@ def cross_compile_for_windows( "'arg_inputs' and 'inputs' should not be used at the same time." ) - arg_inputs = inputs or arg_inputs + arg_inputs = inputs or arg_inputs # type: ignore[assignment] if kwarg_inputs is None: kwarg_inputs = {} @@ -524,7 +514,7 @@ def convert_method_to_trt_engine( raise AssertionError( "'arg_inputs' and 'inputs' should not be used at the same time." ) - arg_inputs = arg_inputs or inputs + arg_inputs = arg_inputs or inputs # type: ignore[assignment] module_type = _parse_module_type(module) target_ir = _get_target_fe(module_type, ir) @@ -844,8 +834,13 @@ def _all_are_input_objects(obj: Any) -> bool: "The explicit dynamic_shapes parameter takes precedence and Input shape specifications will be ignored." 
) else: - inferred_dynamic_shapes = get_dynamic_shapes_args(module, arg_inputs) - inferred_dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) + inferred_dynamic_shapes = get_dynamic_shapes_args( + module, arg_inputs, dim_registry + ) + inferred_dynamic_shapes.update( + get_dynamic_shapes_kwargs(kwarg_inputs, dim_registry) + ) if inferred_dynamic_shapes is not None: dynamic_shapes = inferred_dynamic_shapes @@ -853,13 +848,8 @@ def _all_are_input_objects(obj: Any) -> bool: f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}" ) -<<<<<<< HEAD - arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore[arg-type] - kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore[assignment] -======= arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) ->>>>>>> 7fa5d838 (dynamic shape arg) else: # Mixed case: some inputs are Tensors, some are Input objects diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index e36c3b5240..3cd809e296 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -19,7 +19,6 @@ def trace( *, arg_inputs: Optional[Tuple[Any, ...]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, - dynamic_shapes: Optional[Any] = None, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -74,11 +73,12 @@ def trace( device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - if dynamic_shapes is None: - # Auto-inferred dims are independent per input; pass 
dynamic_shapes - # explicitly to share a Dim across inputs. - dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs) - dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + # Build dynamic shapes from the Input objects. Inputs carrying name_dims + # share a Dim across inputs via the registry; the rest get an independent + # per-input Dim. + dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) + dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs, dim_registry) + dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs, dim_registry)) exp_program = export( mod, tuple(torch_arg_inputs), @@ -90,49 +90,109 @@ def trace( return exp_program -def get_dynamic_shapes_kwargs(inputs: Any) -> Union[dict[str, Any], list[Any]]: +def _collect_inputs(obj: Any) -> list[Input]: + """Flatten an arg/kwarg input structure into a list of Input objects.""" + if isinstance(obj, Input): + return [obj] + elif isinstance(obj, dict): + collected: list[Input] = [] + for v in obj.values(): + collected.extend(_collect_inputs(v)) + return collected + elif isinstance(obj, (list, tuple)): + collected = [] + for v in obj: + collected.extend(_collect_inputs(v)) + return collected + return [] + + +def build_dim_registry(arg_inputs: Any, kwarg_inputs: Any) -> dict[str, Any]: + """Build a ``{name: torch.export.Dim}`` registry from Input.name_dims. + + The same name appearing on multiple inputs yields a single shared ``Dim`` + instance, so ``torch.export`` treats those axes as one symbol. Conflicting + (min, max) ranges for the same name are rejected. 
+ """ + registry: dict[str, Any] = {} + bounds: dict[str, tuple[int, int]] = {} + for inp in _collect_inputs(arg_inputs) + _collect_inputs(kwarg_inputs): + name_dims = getattr(inp, "name_dims", None) + if not name_dims or inp.shape_mode != Input._ShapeMode.DYNAMIC: + continue + assert isinstance(inp.shape, dict) + min_shape = inp.shape["min_shape"] + max_shape = inp.shape["max_shape"] + for axis, dim_name in name_dims.items(): + lo, hi = int(min_shape[axis]), int(max_shape[axis]) + if dim_name in bounds: + if bounds[dim_name] != (lo, hi): + raise ValueError( + f"Dimension name '{dim_name}' is used with conflicting ranges " + f"{bounds[dim_name]} and {(lo, hi)}. A shared named dimension " + f"must have identical (min, max) on every input that uses it." + ) + else: + bounds[dim_name] = (lo, hi) + registry[dim_name] = Dim(dim_name, min=lo, max=hi) + return registry + + +def get_dynamic_shapes_kwargs( + inputs: Any, dim_registry: Optional[dict[str, Any]] = None +) -> Union[dict[str, Any], list[Any]]: if isinstance(inputs, dict): dynamic_shapes_kwarg = {} for k, v in inputs.items(): - dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v) + dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v, dim_registry) return dynamic_shapes_kwarg elif isinstance(inputs, Input): - return get_dynamic_shapes(inputs) + return get_dynamic_shapes(inputs, dim_registry) elif isinstance(inputs, (list, tuple)): dynamic_shapes = [] for input in inputs: - dynamic_shapes.append(get_dynamic_shapes(input)) + dynamic_shapes.append(get_dynamic_shapes(input, dim_registry)) return dynamic_shapes raise TypeError(f"Unknown type {type(inputs)}.") -def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]: +def get_dynamic_shapes_args( + mod: torch.nn.Module, inputs: Any, dim_registry: Optional[dict[str, Any]] = None +) -> dict[str, Any]: # dynamic_shape is a dict and cannot work without keys. 
Here we use position argument name # in forward function as the name args = list(signature(mod.forward).parameters.keys()) dynamic_shapes = {} for input, input_name in zip(inputs, args[: len(inputs)]): - dynamic_shapes[input_name] = get_dynamic_shapes(input) + dynamic_shapes[input_name] = get_dynamic_shapes(input, dim_registry) return dynamic_shapes -def get_dynamic_shapes(input: Input) -> dict[Any, Any]: +def get_dynamic_shapes( + input: Input, dim_registry: Optional[dict[str, Any]] = None +) -> dict[Any, Any]: if not isinstance(input, Input): # If the input is torch.Tensor, no dynamic is needed. Return empty dict return {} else: dynamic_dims = {} if input.shape_mode == Input._ShapeMode.DYNAMIC: + assert isinstance(input.shape, dict) min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] + name_dims = getattr(input, "name_dims", None) or {} assert len(min_shape) == len(opt_shape) == len(max_shape) for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: continue + elif dim_registry is not None and dim in name_dims: + # Named axis: reuse the shared Dim so axes with the same + # name across inputs become a single exported symbol. + dynamic_dims[dim] = dim_registry[name_dims[dim]] else: dynamic_dims[dim] = Dim( input.name + "_" + str(dim), diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py index 1f281e6851..25ff7e292f 100644 --- a/tests/py/dynamo/models/test_shared_dynamic_dim.py +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -1,28 +1,27 @@ # type: ignore """ -Tests for the ``dynamic_shapes=`` passthrough kwarg on ``torch_tensorrt.compile``. +Tests for sharing a dynamic dimension across inputs via ``Input(name_dims=...)``. Background: when a model takes multiple inputs whose dynamic axes must be **equal at runtime** (e.g. 
HF encoders with ``input_ids`` / ``attention_mask`` -both shaped ``[B, S]``), the legacy auto-inference path in -``dynamo/_tracer.py`` mints an *independent* ``Dim`` per input. ``torch.export`` -then fails its constraint check for any forward() that broadcasts across those -axes (here: ``embed(input_ids) * mask.unsqueeze(-1)``), raising -``ConstraintViolationError``. - -These tests exercise the new ``dynamic_shapes=`` passthrough that lets the -caller supply a shared ``Dim`` directly to ``torch_tensorrt.compile`` -- -mirroring the ``torch.export.export(dynamic_shapes=...)`` signature -- so the -shared-batch case compiles end to end without the caller having to pre-export -the module themselves. +both shaped ``[B, S]``), the default per-input inference in ``dynamo/_tracer.py`` +mints an *independent* ``Dim`` for each input's axis. ``torch.export`` then fails its +constraint check for any forward() that broadcasts across those axes (here: +``embed(input_ids) * mask.unsqueeze(-1)``), raising ``ConstraintViolationError``. + +``Input(name_dims={axis: name})`` lets the caller tag a dynamic axis with a +name; the same name across inputs is exported as a single shared ``Dim``. All +the dynamic-shape intent lives on the ``Input`` objects -- no separate +``dynamic_shapes`` argument and no ``torch.export`` knowledge required at the +call site. """ + import unittest import pytest import torch import torch.nn as nn import torch_tensorrt as torchtrt -from torch.export import Dim from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity assertions = unittest.TestCase() @@ -32,8 +31,8 @@ class _SharedBatchEncoder(nn.Module): """HF-style encoder stand-in: two int64 inputs sharing the batch axis. The ``embed(input_ids) * mask.unsqueeze(-1)`` broadcast forces - ``input_ids.size(0) == attention_mask.size(0)`` -- the relationship the - auto-inference path cannot express. + ``input_ids.size(0) == attention_mask.size(0)`` -- the relationship a shared + named dimension expresses.
""" def __init__(self, vocab: int = 1024, hidden: int = 32): @@ -47,43 +46,34 @@ def forward(self, input_ids, attention_mask): return self.proj(x * mask) -def _kwarg_inputs(seq: int = 16, batch_min: int = 1, batch_max: int = 4): - return { - "input_ids": torchtrt.Input( - min_shape=(batch_min, seq), - opt_shape=(batch_max, seq), - max_shape=(batch_max, seq), - dtype=torch.int64, - name="input_ids", - ), - "attention_mask": torchtrt.Input( - min_shape=(batch_min, seq), - opt_shape=(batch_max, seq), - max_shape=(batch_max, seq), - dtype=torch.int64, - name="attention_mask", - ), - } +def _named_input(name: str, seq: int = 16, batch_min: int = 1, batch_max: int = 4): + """A dynamic int64 Input whose batch axis (0) is named "B" for sharing.""" + return torchtrt.Input( + min_shape=(batch_min, seq), + opt_shape=(batch_max, seq), + max_shape=(batch_max, seq), + dtype=torch.int64, + name=name, + name_dims={0: "B"}, + ) @pytest.mark.unit @pytest.mark.critical -def test_dynamic_shapes_passthrough_with_shared_batch_dim(): - """With ``dynamic_shapes={..: {0: batch}, ..: {0: batch}}`` (one shared - ``Dim``), compile succeeds and the engine matches the eager model.""" +def test_name_dims_shared_batch_kwarg_inputs(): + """Shared batch axis declared via ``Input(name_dims={0: "B"})`` on both + kwarg inputs -- same name => one exported symbol; engine matches eager.""" model = _SharedBatchEncoder().eval().cuda() - batch = Dim("batch", min=1, max=4) - dynamic_shapes = { - "input_ids": {0: batch}, - "attention_mask": {0: batch}, + kwarg_inputs = { + "input_ids": _named_input("input_ids"), + "attention_mask": _named_input("attention_mask"), } trt_mod = torchtrt.compile( model, ir="dynamo", - kwarg_inputs=_kwarg_inputs(), - dynamic_shapes=dynamic_shapes, + kwarg_inputs=kwarg_inputs, min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, @@ -101,50 +91,32 @@ def test_dynamic_shapes_passthrough_with_shared_batch_dim(): cos_sim = cosine_similarity(ref, out) 
assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Shared-batch encoder out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"name_dims shared batch (kwargs) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_dynamic_shapes_passthrough_positional_tuple_form(): - """``torch.export`` also accepts ``dynamic_shapes`` as a tuple matching the - positional-args order. Verify the passthrough handles that form too.""" +def test_name_dims_shared_batch_positional_inputs(): + """Same feature with positional ``inputs=[...]`` instead of kwargs.""" model = _SharedBatchEncoder().eval().cuda() - batch = Dim("batch", min=1, max=4) - seq = 16 positional_inputs = [ - torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="input_ids", - ), - torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="attention_mask", - ), + _named_input("input_ids"), + _named_input("attention_mask"), ] - # Tuple form: one entry per positional arg, in declaration order. 
- dynamic_shapes = ({0: batch}, {0: batch}) trt_mod = torchtrt.compile( model, ir="dynamo", inputs=positional_inputs, - dynamic_shapes=dynamic_shapes, min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, ) for bs in (4, 2): - ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") - mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") with torch.no_grad(): ref = model(ids, mask) @@ -153,58 +125,28 @@ def test_dynamic_shapes_passthrough_positional_tuple_form(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Tuple-form dynamic_shapes out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"name_dims shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_dynamic_shapes_passthrough_mixed_args_and_kwargs(): - """One positional input, one kwarg input, sharing a batch ``Dim``. Uses the - unified dict-by-parameter-name form, which spans both positional and keyword - parameters.""" +def test_name_dims_shared_batch_mixed_args_and_kwargs(): + """input_ids passed positionally, attention_mask as a kwarg; both share "B".""" model = _SharedBatchEncoder().eval().cuda() - batch = Dim("batch", min=1, max=4) - seq = 16 - - # input_ids passed positionally, attention_mask as a kwarg. 
- positional_inputs = [ - torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="input_ids", - ), - ] - kwarg_inputs = { - "attention_mask": torchtrt.Input( - min_shape=(1, seq), - opt_shape=(4, seq), - max_shape=(4, seq), - dtype=torch.int64, - name="attention_mask", - ), - } - dynamic_shapes = { - "input_ids": {0: batch}, - "attention_mask": {0: batch}, - } - trt_mod = torchtrt.compile( model, ir="dynamo", - inputs=positional_inputs, - kwarg_inputs=kwarg_inputs, - dynamic_shapes=dynamic_shapes, + inputs=[_named_input("input_ids")], + kwarg_inputs={"attention_mask": _named_input("attention_mask")}, min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, ) for bs in (4, 2): - ids = torch.randint(0, 1024, (bs, seq), dtype=torch.int64, device="cuda") - mask = torch.ones((bs, seq), dtype=torch.int64, device="cuda") + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") with torch.no_grad(): ref = model(ids, attention_mask=mask) @@ -213,14 +155,55 @@ def test_dynamic_shapes_passthrough_mixed_args_and_kwargs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"Mixed args/kwargs out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"name_dims shared batch (mixed) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_name_dims_conflicting_ranges_raises(): + """Same name with different (min, max) across inputs is a user error.""" + from torch_tensorrt.dynamo._tracer import build_dim_registry + + seq = 16 + inputs = { + "input_ids": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="input_ids", + name_dims={0: "B"}, + ), + "attention_mask": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(8, seq), + max_shape=(8, seq), + dtype=torch.int64, + name="attention_mask", + name_dims={0: "B"}, + 
), + } + with assertions.assertRaises(ValueError): + build_dim_registry((), inputs) + + +@pytest.mark.unit +def test_name_dims_rejected_on_static_axis(): + """Naming a static axis (min == max) is rejected at Input construction.""" + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=(1, 16), + opt_shape=(1, 16), + max_shape=(1, 16), + dtype=torch.int64, + name="x", + name_dims={0: "B"}, ) @pytest.mark.unit -def test_dynamic_shapes_default_path_unchanged_for_static_inputs(): - """Sanity check: when ``dynamic_shapes=None`` and inputs are fully static, - behavior is unchanged from the legacy path.""" +def test_default_path_unchanged_for_static_inputs(): + """Sanity check: a fully static input with no name_dims is unchanged.""" class StaticModel(nn.Module): def __init__(self): From d1e4c2885a2905b484dfea7fada3291022d61eca Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 4 Jun 2026 10:58:59 -0700 Subject: [PATCH 3/6] adding testcase --- tests/py/dynamo/models/test_shared_dynamic_dim.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py index 25ff7e292f..b0f7c94d9b 100644 --- a/tests/py/dynamo/models/test_shared_dynamic_dim.py +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -201,6 +201,20 @@ def test_name_dims_rejected_on_static_axis(): ) +@pytest.mark.unit +def test_name_dims_rejected_on_out_of_range_axis(): + """An axis index outside the input's rank is rejected at construction.""" + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(4, 16), + dtype=torch.int64, + name="x", + name_dims={5: "B"}, # rank is 2; axis 5 does not exist + ) + + @pytest.mark.unit def test_default_path_unchanged_for_static_inputs(): """Sanity check: a fully static input with no name_dims is unchanged.""" From 6b3801b15033e89ebd84680c126a92136916121f Mon Sep 17 00:00:00 2001 From: 
apbose Date: Mon, 8 Jun 2026 13:02:12 -0700 Subject: [PATCH 4/6] replacing named_dims with shared_dims --- py/torch_tensorrt/_Input.py | 28 +++++++-------- py/torch_tensorrt/dynamo/_tracer.py | 16 ++++----- .../dynamo/models/test_shared_dynamic_dim.py | 36 +++++++++---------- 3 files changed, 40 insertions(+), 40 deletions(-) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 27ee329f69..4e078daac4 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -51,7 +51,7 @@ class _ShapeMode(Enum): torch_tensor: torch.Tensor = None name: str = "" is_shape_tensor: bool = False - name_dims: Dict[int, str] = ( + shared_dims: Dict[int, str] = ( {} ) #: Optional {axis_index: name} for dynamic axes. The same name across inputs is exported as one shared ``torch.export.Dim`` (e.g. a batch axis shared by ``input_ids`` and ``attention_mask``). @@ -165,9 +165,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "if you try to run inference with empty tensor inputs." ) - if "name_dims" in kwargs and kwargs["name_dims"]: - self.name_dims = Input._parse_name_dims( - kwargs["name_dims"], self.shape + if "shared_dims" in kwargs and kwargs["shared_dims"]: + self.shared_dims = Input._parse_shared_dims( + kwargs["shared_dims"], self.shape ) else: @@ -175,9 +175,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) - if kwargs.get("name_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC: + if kwargs.get("shared_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC: raise ValueError( - "name_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); " + "shared_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); " "it has no meaning for a statically shaped Input." 
) @@ -276,29 +276,29 @@ def equivalent_spec(a: Input, b: Input) -> bool: return all(checks) @staticmethod - def _parse_name_dims( - name_dims: Any, shape: Dict[str, Tuple[int, ...]] + def _parse_shared_dims( + shared_dims: Any, shape: Dict[str, Tuple[int, ...]] ) -> Dict[int, str]: - """Validate and normalize the ``name_dims`` mapping ({axis: name}). + """Validate and normalize the ``shared_dims`` mapping ({axis: name}). Each named axis must be a valid index into the shape and must be genuinely dynamic (``min != max``); a static axis cannot vary, so naming it for cross-input sharing is a user error. """ - if not isinstance(name_dims, dict): + if not isinstance(shared_dims, dict): raise TypeError( - f"name_dims must be a dict of {{axis_index: name}}, got {type(name_dims)}" + f"shared_dims must be a dict of {{axis_index: name}}, got {type(shared_dims)}" ) rank = len(shape["min_shape"]) parsed: Dict[int, str] = {} - for axis, dim_name in name_dims.items(): + for axis, dim_name in shared_dims.items(): if not isinstance(axis, int) or not (0 <= axis < rank): raise ValueError( - f"name_dims key {axis!r} is not a valid axis index for an input of rank {rank}" + f"shared_dims key {axis!r} is not a valid axis index for an input of rank {rank}" ) if not isinstance(dim_name, str) or not dim_name: raise ValueError( - f"name_dims value for axis {axis} must be a non-empty string, got {dim_name!r}" + f"shared_dims value for axis {axis} must be a non-empty string, got {dim_name!r}" ) if shape["min_shape"][axis] == shape["max_shape"][axis]: raise ValueError( diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 3cd809e296..399559347b 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -73,7 +73,7 @@ def trace( device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - # Build 
dynamic shapes from the Input objects. Inputs carrying name_dims + # Build dynamic shapes from the Input objects. Inputs carrying shared_dims # share a Dim across inputs via the registry; the rest get an independent # per-input Dim. dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) @@ -108,7 +108,7 @@ def _collect_inputs(obj: Any) -> list[Input]: def build_dim_registry(arg_inputs: Any, kwarg_inputs: Any) -> dict[str, Any]: - """Build a ``{name: torch.export.Dim}`` registry from Input.name_dims. + """Build a ``{name: torch.export.Dim}`` registry from Input.shared_dims. The same name appearing on multiple inputs yields a single shared ``Dim`` instance, so ``torch.export`` treats those axes as one symbol. Conflicting @@ -117,13 +117,13 @@ def build_dim_registry(arg_inputs: Any, kwarg_inputs: Any) -> dict[str, Any]: registry: dict[str, Any] = {} bounds: dict[str, tuple[int, int]] = {} for inp in _collect_inputs(arg_inputs) + _collect_inputs(kwarg_inputs): - name_dims = getattr(inp, "name_dims", None) - if not name_dims or inp.shape_mode != Input._ShapeMode.DYNAMIC: + shared_dims = getattr(inp, "shared_dims", None) + if not shared_dims or inp.shape_mode != Input._ShapeMode.DYNAMIC: continue assert isinstance(inp.shape, dict) min_shape = inp.shape["min_shape"] max_shape = inp.shape["max_shape"] - for axis, dim_name in name_dims.items(): + for axis, dim_name in shared_dims.items(): lo, hi = int(min_shape[axis]), int(max_shape[axis]) if dim_name in bounds: if bounds[dim_name] != (lo, hi): @@ -184,15 +184,15 @@ def get_dynamic_shapes( min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] - name_dims = getattr(input, "name_dims", None) or {} + shared_dims = getattr(input, "shared_dims", None) or {} assert len(min_shape) == len(opt_shape) == len(max_shape) for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: continue - elif dim_registry is not None and dim in name_dims: + 
elif dim_registry is not None and dim in shared_dims: # Named axis: reuse the shared Dim so axes with the same # name across inputs become a single exported symbol. - dynamic_dims[dim] = dim_registry[name_dims[dim]] + dynamic_dims[dim] = dim_registry[shared_dims[dim]] else: dynamic_dims[dim] = Dim( input.name + "_" + str(dim), diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py index b0f7c94d9b..9d1ac28fbb 100644 --- a/tests/py/dynamo/models/test_shared_dynamic_dim.py +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -1,6 +1,6 @@ # type: ignore """ -Tests for sharing a dynamic dimension across inputs via ``Input(name_dims=...)``. +Tests for sharing a dynamic dimension across inputs via ``Input(shared_dims=...)``. Background: when a model takes multiple inputs whose dynamic axes must be **equal at runtime** (e.g. HF encoders with ``input_ids`` / ``attention_mask`` @@ -9,7 +9,7 @@ constraint check for any forward() that broadcasts across those axes (here: ``embed(input_ids) * mask.unsqueeze(-1)``), raising ``ConstraintViolationError``. -``Input(name_dims={axis: name})`` lets the caller tag a dynamic axis with a +``Input(shared_dims={axis: name})`` lets the caller tag a dynamic axis with a name; the same name across inputs is exported as a single shared ``Dim``. 
All the dynamic-shape intent lives on the ``Input`` objects -- no separate ``dynamic_shapes`` argument and no ``torch.export`` knowledge required at the @@ -54,14 +54,14 @@ def _named_input(name: str, seq: int = 16, batch_min: int = 1, batch_max: int = max_shape=(batch_max, seq), dtype=torch.int64, name=name, - name_dims={0: "B"}, + shared_dims={0: "B"}, ) @pytest.mark.unit @pytest.mark.critical -def test_name_dims_shared_batch_kwarg_inputs(): - """Shared batch axis declared via ``Input(name_dims={0: "B"})`` on both +def test_shared_dims_shared_batch_kwarg_inputs(): + """Shared batch axis declared via ``Input(shared_dims={0: "B"})`` on both kwarg inputs -- same name => one exported symbol; engine matches eager.""" model = _SharedBatchEncoder().eval().cuda() @@ -91,12 +91,12 @@ def test_name_dims_shared_batch_kwarg_inputs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"name_dims shared batch (kwargs) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"shared_dims shared batch (kwargs) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_name_dims_shared_batch_positional_inputs(): +def test_shared_dims_shared_batch_positional_inputs(): """Same feature with positional ``inputs=[...]`` instead of kwargs.""" model = _SharedBatchEncoder().eval().cuda() @@ -125,12 +125,12 @@ def test_name_dims_shared_batch_positional_inputs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"name_dims shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"shared_dims shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_name_dims_shared_batch_mixed_args_and_kwargs(): +def test_shared_dims_shared_batch_mixed_args_and_kwargs(): """input_ids passed positionally, attention_mask as a kwarg; both share "B".""" model = _SharedBatchEncoder().eval().cuda() @@ -155,12 +155,12 @@ def 
test_name_dims_shared_batch_mixed_args_and_kwargs(): cos_sim = cosine_similarity(ref, out) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, - f"name_dims shared batch (mixed) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + f"shared_dims shared batch (mixed) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", ) @pytest.mark.unit -def test_name_dims_conflicting_ranges_raises(): +def test_shared_dims_conflicting_ranges_raises(): """Same name with different (min, max) across inputs is a user error.""" from torch_tensorrt.dynamo._tracer import build_dim_registry @@ -172,7 +172,7 @@ def test_name_dims_conflicting_ranges_raises(): max_shape=(4, seq), dtype=torch.int64, name="input_ids", - name_dims={0: "B"}, + shared_dims={0: "B"}, ), "attention_mask": torchtrt.Input( min_shape=(1, seq), @@ -180,7 +180,7 @@ def test_name_dims_conflicting_ranges_raises(): max_shape=(8, seq), dtype=torch.int64, name="attention_mask", - name_dims={0: "B"}, + shared_dims={0: "B"}, ), } with assertions.assertRaises(ValueError): @@ -188,7 +188,7 @@ def test_name_dims_conflicting_ranges_raises(): @pytest.mark.unit -def test_name_dims_rejected_on_static_axis(): +def test_shared_dims_rejected_on_static_axis(): """Naming a static axis (min == max) is rejected at Input construction.""" with assertions.assertRaises(ValueError): torchtrt.Input( @@ -197,12 +197,12 @@ def test_name_dims_rejected_on_static_axis(): max_shape=(1, 16), dtype=torch.int64, name="x", - name_dims={0: "B"}, + shared_dims={0: "B"}, ) @pytest.mark.unit -def test_name_dims_rejected_on_out_of_range_axis(): +def test_shared_dims_rejected_on_out_of_range_axis(): """An axis index outside the input's rank is rejected at construction.""" with assertions.assertRaises(ValueError): torchtrt.Input( @@ -211,13 +211,13 @@ def test_name_dims_rejected_on_out_of_range_axis(): max_shape=(4, 16), dtype=torch.int64, name="x", - name_dims={5: "B"}, # rank is 2; axis 5 does not exist + shared_dims={5: "B"}, # rank is 2; axis 5 does not exist ) 
@pytest.mark.unit def test_default_path_unchanged_for_static_inputs(): - """Sanity check: a fully static input with no name_dims is unchanged.""" + """Sanity check: a fully static input with no shared_dims is unchanged.""" class StaticModel(nn.Module): def __init__(self): From f907b6463f0984eb04ff1af51c13998a9c96b21c Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 11 Jun 2026 10:55:00 -0700 Subject: [PATCH 5/6] adding examples for shared dims in both index.rst and a short section in dynamic_shapes.rst --- .../user_guide/compilation/dynamic_shapes.rst | 43 +++++ docsrc/user_guide/compilation/index.rst | 1 + .../dynamo/shared_dynamic_dims_example.py | 181 ++++++++++++++++++ 3 files changed, 225 insertions(+) create mode 100644 examples/dynamo/shared_dynamic_dims_example.py diff --git a/docsrc/user_guide/compilation/dynamic_shapes.rst b/docsrc/user_guide/compilation/dynamic_shapes.rst index 8e11e923b6..5d27fe8483 100644 --- a/docsrc/user_guide/compilation/dynamic_shapes.rst +++ b/docsrc/user_guide/compilation/dynamic_shapes.rst @@ -78,6 +78,49 @@ Here's a simple example that exports a matmul layer with some restrictions on dy # Run inference trt_gm(*inputs) +Sharing a Dynamic Dimension Across Multiple Inputs +--------------------------------------------------- + +HuggingFace-style encoders and similar models take multiple inputs (e.g. +``input_ids`` and ``attention_mask``) whose dynamic axes **must be equal at +runtime**. If you assign an independent dynamic dimension to each input, +``torch.export`` detects that the two independent symbols are forced equal by +the model's forward pass (e.g. a broadcast) and raises a +``ConstraintViolationError``. + +``torch_tensorrt.Input(shared_dims={axis: name})`` solves this without any +manual ``torch.export`` work. Axes that share the same name across inputs are +exported as a single ``torch.export.Dim``, so the equality constraint is +satisfied automatically. + +.. 
code-block:: python + + import torch + import torch_tensorrt + + model = MyHFEncoder().eval().cuda() + + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 16), opt_shape=(4, 16), max_shape=(8, 16), + dtype=torch.int64, name="input_ids", + shared_dims={0: "B"}, # axis 0 named "B" + ), + torch_tensorrt.Input( + min_shape=(1, 16), opt_shape=(4, 16), max_shape=(8, 16), + dtype=torch.int64, name="attention_mask", + shared_dims={0: "B"}, # same name → same Dim + ), + ] + + trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) + +The same name across different inputs produces one shared ``Dim``, whatever the +axis index (each use must declare the same min/max range, else ``ValueError``); different +names produce independent ``Dim``\s. Multiple axes can be shared simultaneously with ``shared_dims={0: "B", 1: "S"}``. + +See the full runnable example: :ref:`shared_dynamic_dims`. + Dynamic shapes using torch.compile (JIT) ------------------------------------ diff --git a/docsrc/user_guide/compilation/index.rst b/docsrc/user_guide/compilation/index.rst index 154e0a61ae..fe1c49ed98 100644 --- a/docsrc/user_guide/compilation/index.rst +++ b/docsrc/user_guide/compilation/index.rst @@ -13,4 +13,5 @@ How Torch-TensorRT compiles models: the JIT ``torch.compile`` path, the AOT compilation_settings dynamic_shapes Example: Compiling Models with Dynamic Input Shapes <../../tutorials/_rendered_examples/dynamo/compile_with_dynamic_inputs> + Example: Sharing Dynamic Dimensions Across Inputs <../../tutorials/_rendered_examples/dynamo/shared_dynamic_dims_example> unsupported_ops diff --git a/examples/dynamo/shared_dynamic_dims_example.py b/examples/dynamo/shared_dynamic_dims_example.py new file mode 100644 index 0000000000..979e5df022 --- /dev/null +++ b/examples/dynamo/shared_dynamic_dims_example.py @@ -0,0 +1,181 @@ +""" +..
_shared_dynamic_dims: + +Sharing Dynamic Dimensions Across Inputs +========================================================== + +When a model takes multiple inputs whose dynamic axes must be **equal at +runtime** — for example, HuggingFace-style encoders where ``input_ids`` and +``attention_mask`` are both shaped ``[batch, seq_len]`` — naively assigning an +independent dynamic dimension to each input causes ``torch.export`` to raise a +``ConstraintViolationError``. The exporter detects that the two independent +symbols are forced equal by the model's forward pass (e.g. a broadcast) and +rejects the export. + +``torch_tensorrt.Input(shared_dims={axis: name})`` solves this: axes that share +the same name across inputs are exported as a single ``torch.export.Dim``, so +the equality constraint is satisfied automatically. All dynamic-shape intent +lives on the ``Input`` objects — no separate ``dynamic_shapes`` argument or +``torch.export`` knowledge is required at the call site. +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch.nn as nn +import torch_tensorrt +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +# %% +# Define a HuggingFace-style encoder whose two inputs share the batch axis. +# The ``embed * mask`` broadcast forces ``input_ids.shape[0] == +# attention_mask.shape[0]`` at every forward call — exactly the pattern that +# triggers ``ConstraintViolationError`` when the batch axis is exported as two +# independent ``Dim`` objects. 
+ + +class SharedDimEncoder(nn.Module): + def __init__(self, vocab: int = 1024, hidden: int = 64): + super().__init__() + self.embed = nn.Embedding(vocab, hidden) + self.proj = nn.Linear(hidden, hidden) + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): + x = self.embed(input_ids) # [B, S, hidden] + mask = attention_mask.unsqueeze(-1).to(x.dtype) # [B, S, 1] + return self.proj(x * mask) # [B, S, hidden] + + +model = SharedDimEncoder().cuda().eval() + +# %% +# Without ``shared_dims`` — raises ``ConstraintViolationError`` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Using independent ``Input`` objects like this would fail at export time: +# +# .. code-block:: python +# +# inputs = [ +# torch_tensorrt.Input(min_shape=(1,16), opt_shape=(4,16), max_shape=(8,16), dtype=torch.int64), +# torch_tensorrt.Input(min_shape=(1,16), opt_shape=(4,16), max_shape=(8,16), dtype=torch.int64), +# ] +# # torch.export mints independent symbols s0, s1 for the batch axis of +# # each input. The broadcast forces Eq(s0, s1), which the exporter rejects. +# +# %% +# With ``shared_dims`` — correct approach (positional inputs) +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Annotate the batch axis (axis 0) with the same name ``"B"`` on both inputs. +# Torch-TensorRT creates a single shared ``torch.export.Dim("B")`` for that +# axis so the equality constraint is satisfied up front. 
+ +inputs = [ + torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="input_ids", + shared_dims={0: "B"}, + ), + torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="attention_mask", + shared_dims={0: "B"}, + ), +] + +trt_model = torch_tensorrt.compile( + model, + ir="dynamo", + inputs=inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, +) + +# %% +# Verify correctness at multiple batch sizes within the declared range. + +for batch_size in (4, 2, 1): + ids = torch.randint(0, 1024, (batch_size, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((batch_size, 16), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(ids, mask) + out = trt_model(ids, mask) + + cos_sim = cosine_similarity(ref, out) + assert ( + cos_sim > COSINE_THRESHOLD + ), f"Numerical mismatch at batch_size={batch_size}: cos_sim={cos_sim:.4f}" + print(f"batch_size={batch_size} cos_sim={cos_sim:.6f} ✓") + +# %% +# With ``shared_dims`` — kwarg inputs +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# The same feature works with ``kwarg_inputs``, which is the natural form for +# HuggingFace models whose ``forward`` signature uses keyword arguments. 
+ +kwarg_inputs = { + "input_ids": torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="input_ids", + shared_dims={0: "B"}, + ), + "attention_mask": torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="attention_mask", + shared_dims={0: "B"}, + ), +} + +trt_model_kwargs = torch_tensorrt.compile( + model, + ir="dynamo", + kwarg_inputs=kwarg_inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, +) + +ids = torch.randint(0, 1024, (4, 16), dtype=torch.int64, device="cuda") +mask = torch.ones((4, 16), dtype=torch.int64, device="cuda") + +with torch.no_grad(): + ref = model(input_ids=ids, attention_mask=mask) + out = trt_model_kwargs(input_ids=ids, attention_mask=mask) + +cos_sim = cosine_similarity(ref, out) +assert cos_sim > COSINE_THRESHOLD, f"kwarg path mismatch: cos_sim={cos_sim:.4f}" +print(f"kwarg_inputs path cos_sim={cos_sim:.6f} ✓") + +# %% +# Sharing multiple axes +# ^^^^^^^^^^^^^^^^^^^^^^ +# +# If both batch and sequence length are dynamic and must be shared, annotate +# both axes on each input: +# +# .. code-block:: python +# +# shared_dims={0: "B", 1: "S"} +# +# The same name across different inputs produces one shared ``Dim`` (each use must +# declare the same min/max range); different names produce independent ``Dim``\s.
+ +print("\nAll checks passed.") From c85ffe6361df57b218a079b46d69030fb990cefc Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 11 Jun 2026 12:01:20 -0700 Subject: [PATCH 6/6] namedtuple in input --- py/torch_tensorrt/_Input.py | 35 ++++++ .../dynamo/models/test_shared_dynamic_dim.py | 111 ++++++++++++++++++ 2 files changed, 146 insertions(+) diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 4e078daac4..ef484cb4bc 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -165,7 +165,42 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "if you try to run inference with empty tensor inputs." ) + # Namedtuple shape API: field names encode per-axis dimension names. + # Convert to shared_dims so _tracer.py needs no changes — axes with the + # same name across inputs become one shared torch.export.Dim. + if hasattr(kwargs["min_shape"], "_fields"): + fields = kwargs["min_shape"]._fields + if not ( + hasattr(kwargs["opt_shape"], "_fields") + and hasattr(kwargs["max_shape"], "_fields") + ): + raise TypeError( + "If min_shape is a namedtuple, opt_shape and max_shape must also be namedtuples" + ) + if not ( + kwargs["opt_shape"]._fields == fields + and kwargs["max_shape"]._fields == fields + ): + raise ValueError( + "min_shape, opt_shape, max_shape namedtuples must have identical field names" + ) + # Only tag dynamic axes (min != max); static axes need no shared Dim. + self.shared_dims = { + i: name + for i, name in enumerate(fields) + if self.shape["min_shape"][i] != self.shape["max_shape"][i] + } + if "shared_dims" in kwargs and kwargs["shared_dims"]: + if hasattr(kwargs["min_shape"], "_fields"): + # Not allowed: + # S = namedtuple('S', ['b', 'c']) + # Input(min_shape=S(1,2), opt_shape=S(4,2), max_shape=S(8,2), + # shared_dims={0: "b"}) # ← redundant and ambiguous + raise ValueError( + "Cannot specify both a namedtuple min_shape and shared_dims; " + "use one or the other to name dynamic axes." 
+ ) self.shared_dims = Input._parse_shared_dims( kwargs["shared_dims"], self.shape ) diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py index 9d1ac28fbb..814987332d 100644 --- a/tests/py/dynamo/models/test_shared_dynamic_dim.py +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -243,5 +243,116 @@ def forward(self, x): assertions.assertTrue(cosine_similarity(ref, out) > COSINE_THRESHOLD) +# --------------------------------------------------------------------------- +# Namedtuple shape API tests +# --------------------------------------------------------------------------- +# +# The namedtuple API lets users name axes by using a namedtuple as the shape +# spec. Field names that appear on multiple inputs are automatically treated +# as a shared torch.export.Dim — no explicit shared_dims kwarg required. +# +# input_shape1 = namedtuple('S', ['n', 'c', 'h', 'w']) +# input_shape2 = namedtuple('S', ['c', 'seq']) +# Both have field 'c' → one shared Dim("c"). 
+ +from collections import namedtuple + + +@pytest.mark.unit +@pytest.mark.critical +def test_namedtuple_shared_batch_positional_inputs(): + """Namedtuple field 'c' shared across two inputs — same as shared_dims={...:'c'}.""" + model = _SharedBatchEncoder().eval().cuda() + + # seq is static + S1 = namedtuple("shape", ["c", "seq"]) + + positional_inputs = [ + torchtrt.Input( + min_shape=S1(1, 16), + opt_shape=S1(4, 16), + max_shape=S1(4, 16), + dtype=torch.int64, + name="input_ids", + ), + torchtrt.Input( + min_shape=S1(1, 16), + opt_shape=S1(4, 16), + max_shape=S1(4, 16), + dtype=torch.int64, + name="attention_mask", + ), + ] + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=positional_inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") + with torch.no_grad(): + ref = model(ids, mask) + out = trt_mod(ids, mask) + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"namedtuple shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_namedtuple_static_axes_skipped(): + """Static axes (min==max) in a namedtuple are not added to shared_dims.""" + S = namedtuple("shape", ["batch", "seq"]) + inp = torchtrt.Input( + min_shape=S(1, 16), + opt_shape=S(4, 16), + max_shape=S(4, 16), + dtype=torch.int64, + name="x", + ) + # 'seq' axis: min=max=16 → static, must not appear in shared_dims + assertions.assertNotIn(1, inp.shared_dims) + # 'batch' axis: min=1, max=4 → dynamic, must appear + assertions.assertIn(0, inp.shared_dims) + assertions.assertEqual(inp.shared_dims[0], "batch") + + +@pytest.mark.unit +def test_namedtuple_mismatched_fields_raises(): + """opt_shape namedtuple with different fields than min_shape is rejected.""" + S1 = namedtuple("shape", ["b", "c"]) + S2 
= namedtuple("shape", ["b", "seq"]) + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=S1(1, 16), + opt_shape=S2(4, 16), + max_shape=S1(4, 16), + dtype=torch.int64, + name="x", + ) + + +@pytest.mark.unit +def test_namedtuple_and_shared_dims_together_raises(): + """Passing both a namedtuple shape and shared_dims kwarg is rejected.""" + S = namedtuple("shape", ["b", "c"]) + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=S(1, 16), + opt_shape=S(4, 16), + max_shape=S(4, 16), + dtype=torch.int64, + name="x", + shared_dims={0: "b"}, + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"])