diff --git a/docsrc/user_guide/compilation/dynamic_shapes.rst b/docsrc/user_guide/compilation/dynamic_shapes.rst index 8e11e923b6..5d27fe8483 100644 --- a/docsrc/user_guide/compilation/dynamic_shapes.rst +++ b/docsrc/user_guide/compilation/dynamic_shapes.rst @@ -78,6 +78,49 @@ Here's a simple example that exports a matmul layer with some restrictions on dy # Run inference trt_gm(*inputs) +Sharing a Dynamic Dimension Across Multiple Inputs +--------------------------------------------------- + +HuggingFace-style encoders and similar models take multiple inputs (e.g. +``input_ids`` and ``attention_mask``) whose dynamic axes **must be equal at +runtime**. If you assign an independent dynamic dimension to each input, +``torch.export`` detects that the two independent symbols are forced equal by +the model's forward pass (e.g. a broadcast) and raises a +``ConstraintViolationError``. + +``torch_tensorrt.Input(shared_dims={axis: name})`` solves this without any +manual ``torch.export`` work. Axes that share the same name across inputs are +exported as a single ``torch.export.Dim``, so the equality constraint is +satisfied automatically. + +.. code-block:: python + + import torch + import torch_tensorrt + + model = MyHFEncoder().eval().cuda() + + inputs = [ + torch_tensorrt.Input( + min_shape=(1, 16), opt_shape=(4, 16), max_shape=(8, 16), + dtype=torch.int64, name="input_ids", + shared_dims={0: "B"}, # axis 0 named "B" + ), + torch_tensorrt.Input( + min_shape=(1, 16), opt_shape=(4, 16), max_shape=(8, 16), + dtype=torch.int64, name="attention_mask", + shared_dims={0: "B"}, # same name → same Dim + ), + ] + + trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) + +The same name across different inputs produces one shared ``Dim``, so every axis that uses a given name must declare the same ``(min, max)`` range (a mismatch raises ``ValueError``); +different names produce independent ``Dim``\s. Multiple axes +can be shared simultaneously with ``shared_dims={0: "B", 1: "S"}``. + +See the full runnable example: :ref:`shared_dynamic_dims`. 
+ Dynamic shapes using torch.compile (JIT) ------------------------------------ diff --git a/docsrc/user_guide/compilation/index.rst b/docsrc/user_guide/compilation/index.rst index 154e0a61ae..fe1c49ed98 100644 --- a/docsrc/user_guide/compilation/index.rst +++ b/docsrc/user_guide/compilation/index.rst @@ -13,4 +13,5 @@ How Torch-TensorRT compiles models: the JIT ``torch.compile`` path, the AOT compilation_settings dynamic_shapes Example: Compiling Models with Dynamic Input Shapes <../../tutorials/_rendered_examples/dynamo/compile_with_dynamic_inputs> + Example: Sharing Dynamic Dimensions Across Inputs <../../tutorials/_rendered_examples/dynamo/shared_dynamic_dims_example> unsupported_ops diff --git a/examples/dynamo/shared_dynamic_dims_example.py b/examples/dynamo/shared_dynamic_dims_example.py new file mode 100644 index 0000000000..979e5df022 --- /dev/null +++ b/examples/dynamo/shared_dynamic_dims_example.py @@ -0,0 +1,181 @@ +""" +.. _shared_dynamic_dims: + +Sharing Dynamic Dimensions Across Inputs +========================================================== + +When a model takes multiple inputs whose dynamic axes must be **equal at +runtime** — for example, HuggingFace-style encoders where ``input_ids`` and +``attention_mask`` are both shaped ``[batch, seq_len]`` — naively assigning an +independent dynamic dimension to each input causes ``torch.export`` to raise a +``ConstraintViolationError``. The exporter detects that the two independent +symbols are forced equal by the model's forward pass (e.g. a broadcast) and +rejects the export. + +``torch_tensorrt.Input(shared_dims={axis: name})`` solves this: axes that share +the same name across inputs are exported as a single ``torch.export.Dim``, so +the equality constraint is satisfied automatically. All dynamic-shape intent +lives on the ``Input`` objects — no separate ``dynamic_shapes`` argument or +``torch.export`` knowledge is required at the call site. 
+""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +import torch.nn as nn +import torch_tensorrt +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +# %% +# Define a HuggingFace-style encoder whose two inputs share the batch axis. +# The ``embed * mask`` broadcast forces ``input_ids.shape[0] == +# attention_mask.shape[0]`` at every forward call — exactly the pattern that +# triggers ``ConstraintViolationError`` when the batch axis is exported as two +# independent ``Dim`` objects. + + +class SharedDimEncoder(nn.Module): + def __init__(self, vocab: int = 1024, hidden: int = 64): + super().__init__() + self.embed = nn.Embedding(vocab, hidden) + self.proj = nn.Linear(hidden, hidden) + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor): + x = self.embed(input_ids) # [B, S, hidden] + mask = attention_mask.unsqueeze(-1).to(x.dtype) # [B, S, 1] + return self.proj(x * mask) # [B, S, hidden] + + +model = SharedDimEncoder().cuda().eval() + +# %% +# Without ``shared_dims`` — raises ``ConstraintViolationError`` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Using independent ``Input`` objects like this would fail at export time: +# +# .. code-block:: python +# +# inputs = [ +# torch_tensorrt.Input(min_shape=(1,16), opt_shape=(4,16), max_shape=(8,16), dtype=torch.int64), +# torch_tensorrt.Input(min_shape=(1,16), opt_shape=(4,16), max_shape=(8,16), dtype=torch.int64), +# ] +# # torch.export mints independent symbols s0, s1 for the batch axis of +# # each input. The broadcast forces Eq(s0, s1), which the exporter rejects. +# +# %% +# With ``shared_dims`` — correct approach (positional inputs) +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# Annotate the batch axis (axis 0) with the same name ``"B"`` on both inputs. 
+# Torch-TensorRT creates a single shared ``torch.export.Dim("B")`` for that +# axis so the equality constraint is satisfied up front. + +inputs = [ + torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="input_ids", + shared_dims={0: "B"}, + ), + torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="attention_mask", + shared_dims={0: "B"}, + ), +] + +trt_model = torch_tensorrt.compile( + model, + ir="dynamo", + inputs=inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, +) + +# %% +# Verify correctness at multiple batch sizes within the declared range. + +for batch_size in (4, 2, 1): + ids = torch.randint(0, 1024, (batch_size, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((batch_size, 16), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(ids, mask) + out = trt_model(ids, mask) + + cos_sim = cosine_similarity(ref, out) + assert ( + cos_sim > COSINE_THRESHOLD + ), f"Numerical mismatch at batch_size={batch_size}: cos_sim={cos_sim:.4f}" + print(f"batch_size={batch_size} cos_sim={cos_sim:.6f} ✓") + +# %% +# With ``shared_dims`` — kwarg inputs +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# +# The same feature works with ``kwarg_inputs``, which is the natural form for +# HuggingFace models whose ``forward`` signature uses keyword arguments. 
+ +kwarg_inputs = { + "input_ids": torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="input_ids", + shared_dims={0: "B"}, + ), + "attention_mask": torch_tensorrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(8, 16), + dtype=torch.int64, + name="attention_mask", + shared_dims={0: "B"}, + ), +} + +trt_model_kwargs = torch_tensorrt.compile( + model, + ir="dynamo", + kwarg_inputs=kwarg_inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, +) + +ids = torch.randint(0, 1024, (4, 16), dtype=torch.int64, device="cuda") +mask = torch.ones((4, 16), dtype=torch.int64, device="cuda") + +with torch.no_grad(): + ref = model(input_ids=ids, attention_mask=mask) + out = trt_model_kwargs(input_ids=ids, attention_mask=mask) + +cos_sim = cosine_similarity(ref, out) +assert cos_sim > COSINE_THRESHOLD, f"kwarg path mismatch: cos_sim={cos_sim:.4f}" +print(f"kwarg_inputs path cos_sim={cos_sim:.6f} ✓") + +# %% +# Sharing multiple axes +# ^^^^^^^^^^^^^^^^^^^^^^ +# +# If both batch and sequence length are dynamic and must be shared, annotate +# both axes on each input: +# +# .. code-block:: python +# +# shared_dims={0: "B", 1: "S"} +# +# The same name across different inputs produces one shared ``Dim``; +# different names (here ``"B"`` and ``"S"``) produce independent ``Dim``\s. + +print("\nAll checks passed.") diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index b547afb278..ef484cb4bc 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -51,6 +51,9 @@ class _ShapeMode(Enum): torch_tensor: torch.Tensor = None name: str = "" is_shape_tensor: bool = False + shared_dims: Dict[int, str] = ( + {} + ) #: Optional {axis_index: name} for dynamic axes. The same name across inputs is exported as one shared ``torch.export.Dim`` (e.g. a batch axis shared by ``input_ids`` and ``attention_mask``). 
def __init__(self, *args: Any, **kwargs: Any) -> None: """__init__ Method for torch_tensorrt.Input @@ -162,11 +165,57 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "if you try to run inference with empty tensor inputs." ) + # Namedtuple shape API: field names encode per-axis dimension names. + # Convert to shared_dims so _tracer.py needs no changes — axes with the + # same name across inputs become one shared torch.export.Dim. + if hasattr(kwargs["min_shape"], "_fields"): + fields = kwargs["min_shape"]._fields + if not ( + hasattr(kwargs["opt_shape"], "_fields") + and hasattr(kwargs["max_shape"], "_fields") + ): + raise TypeError( + "If min_shape is a namedtuple, opt_shape and max_shape must also be namedtuples" + ) + if not ( + kwargs["opt_shape"]._fields == fields + and kwargs["max_shape"]._fields == fields + ): + raise ValueError( + "min_shape, opt_shape, max_shape namedtuples must have identical field names" + ) + # Only tag dynamic axes (min != max); static axes need no shared Dim. + self.shared_dims = { + i: name + for i, name in enumerate(fields) + if self.shape["min_shape"][i] != self.shape["max_shape"][i] + } + + if "shared_dims" in kwargs and kwargs["shared_dims"]: + if hasattr(kwargs["min_shape"], "_fields"): + # Not allowed: + # S = namedtuple('S', ['b', 'c']) + # Input(min_shape=S(1,2), opt_shape=S(4,2), max_shape=S(8,2), + # shared_dims={0: "b"}) # ← redundant and ambiguous + raise ValueError( + "Cannot specify both a namedtuple min_shape and shared_dims; " + "use one or the other to name dynamic axes." 
+ ) + self.shared_dims = Input._parse_shared_dims( + kwargs["shared_dims"], self.shape + ) + else: raise ValueError( f"Unexpected number of positional arguments for class Input \n Found {len(args)} arguments, expected either zero or a single positional arguments" ) + if kwargs.get("shared_dims") and self.shape_mode != Input._ShapeMode.DYNAMIC: + raise ValueError( + "shared_dims is only valid for dynamic inputs (min_shape/opt_shape/max_shape); " + "it has no meaning for a statically shaped Input." + ) + if "dtype" in kwargs: self.dtype = dtype._from(kwargs["dtype"]) @@ -261,6 +310,39 @@ def equivalent_spec(a: Input, b: Input) -> bool: ] return all(checks) + @staticmethod + def _parse_shared_dims( + shared_dims: Any, shape: Dict[str, Tuple[int, ...]] + ) -> Dict[int, str]: + """Validate and normalize the ``shared_dims`` mapping ({axis: name}). + + Each named axis must be a valid index into the shape and must be + genuinely dynamic (``min != max``); a static axis cannot vary, so naming + it for cross-input sharing is a user error. + """ + if not isinstance(shared_dims, dict): + raise TypeError( + f"shared_dims must be a dict of {{axis_index: name}}, got {type(shared_dims)}" + ) + rank = len(shape["min_shape"]) + parsed: Dict[int, str] = {} + for axis, dim_name in shared_dims.items(): + if not isinstance(axis, int) or not (0 <= axis < rank): + raise ValueError( + f"shared_dims key {axis!r} is not a valid axis index for an input of rank {rank}" + ) + if not isinstance(dim_name, str) or not dim_name: + raise ValueError( + f"shared_dims value for axis {axis} must be a non-empty string, got {dim_name!r}" + ) + if shape["min_shape"][axis] == shape["max_shape"][axis]: + raise ValueError( + f"Axis {axis} named '{dim_name}' is static " + f"(min == max == {shape['min_shape'][axis]}); only dynamic axes can be named." 
+ ) + parsed[axis] = dim_name + return parsed + @staticmethod def _supported_input_size_type(input_size: Any) -> bool: if isinstance(input_size, torch.Size): diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 65986a276c..74f91b5bd0 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -7,7 +7,17 @@ import platform import warnings from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import torch from torch_tensorrt._enums import dtype @@ -52,6 +62,7 @@ ) from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._tracer import ( + build_dim_registry, get_dynamic_shapes_args, get_dynamic_shapes_kwargs, ) @@ -296,7 +307,7 @@ def _fx_input_interface( return compiled_fx_module elif target_ir == _IRType.dynamo: # Prepare torch and torchtrt inputs - if arg_inputs is None and inputs is None: + if arg_inputs is None and inputs is None and not kwarg_inputs: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") elif arg_inputs is not None and inputs is not None: @@ -304,15 +315,17 @@ def _fx_input_interface( "'arg_inputs' and 'inputs' should not be used at the same time." ) if inputs is not None: - arg_inputs = inputs + arg_inputs = inputs # type: ignore[assignment] if kwarg_inputs is None: kwarg_inputs = {} from torch_tensorrt.dynamo.utils import prepare_inputs - if not isinstance(arg_inputs, collections.abc.Sequence): - arg_inputs = [arg_inputs] # type: ignore + if arg_inputs is None: + arg_inputs = [] + elif not isinstance(arg_inputs, collections.abc.Sequence): + arg_inputs = [arg_inputs] torchtrt_arg_inputs = prepare_inputs(arg_inputs) torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs) @@ -401,7 +414,7 @@ def cross_compile_for_windows( "'arg_inputs' and 'inputs' should not be used at the same time." 
) - arg_inputs = inputs or arg_inputs + arg_inputs = inputs or arg_inputs # type: ignore[assignment] if kwarg_inputs is None: kwarg_inputs = {} @@ -501,7 +514,7 @@ def convert_method_to_trt_engine( raise AssertionError( "'arg_inputs' and 'inputs' should not be used at the same time." ) - arg_inputs = arg_inputs or inputs + arg_inputs = arg_inputs or inputs # type: ignore[assignment] module_type = _parse_module_type(module) target_ir = _get_target_fe(module_type, ir) @@ -821,8 +834,13 @@ def _all_are_input_objects(obj: Any) -> bool: "The explicit dynamic_shapes parameter takes precedence and Input shape specifications will be ignored." ) else: - inferred_dynamic_shapes = get_dynamic_shapes_args(module, arg_inputs) - inferred_dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) + inferred_dynamic_shapes = get_dynamic_shapes_args( + module, arg_inputs, dim_registry + ) + inferred_dynamic_shapes.update( + get_dynamic_shapes_kwargs(kwarg_inputs, dim_registry) + ) if inferred_dynamic_shapes is not None: dynamic_shapes = inferred_dynamic_shapes @@ -830,8 +848,8 @@ def _all_are_input_objects(obj: Any) -> bool: f"Inferred dynamic_shapes from torch_tensorrt.Input objects with min/opt/max specifications: {dynamic_shapes}" ) - arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) # type: ignore[arg-type] - kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) # type: ignore[assignment] + arg_tensors = tuple(get_torch_inputs(arg_inputs, default_device())) + kwarg_tensors = get_torch_inputs(kwarg_inputs, default_device()) else: # Mixed case: some inputs are Tensors, some are Input objects diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index 0595c6a8f9..399559347b 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -65,7 +65,7 @@ def trace( raise AssertionError( "'arg_inputs' and 'inputs' should not be 
used at the same time." ) - arg_inputs = inputs or arg_inputs + arg_inputs = inputs if inputs is not None else arg_inputs if kwarg_inputs is None: kwarg_inputs = {} @@ -73,9 +73,12 @@ def trace( device = to_torch_device(kwargs.get("device", default_device())) torch_arg_inputs = get_torch_inputs(arg_inputs, device) torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - # Constructing dynamic shape list as a nested dict - dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs) - dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs)) + # Build dynamic shapes from the Input objects. Inputs carrying shared_dims + # share a Dim across inputs via the registry; the rest get an independent + # per-input Dim. + dim_registry = build_dim_registry(arg_inputs, kwarg_inputs) + dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs, dim_registry) + dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs, dim_registry)) exp_program = export( mod, tuple(torch_arg_inputs), @@ -87,49 +90,109 @@ def trace( return exp_program -def get_dynamic_shapes_kwargs(inputs: Any) -> Union[dict[str, Any], list[Any]]: +def _collect_inputs(obj: Any) -> list[Input]: + """Flatten an arg/kwarg input structure into a list of Input objects.""" + if isinstance(obj, Input): + return [obj] + elif isinstance(obj, dict): + collected: list[Input] = [] + for v in obj.values(): + collected.extend(_collect_inputs(v)) + return collected + elif isinstance(obj, (list, tuple)): + collected = [] + for v in obj: + collected.extend(_collect_inputs(v)) + return collected + return [] + + +def build_dim_registry(arg_inputs: Any, kwarg_inputs: Any) -> dict[str, Any]: + """Build a ``{name: torch.export.Dim}`` registry from Input.shared_dims. + + The same name appearing on multiple inputs yields a single shared ``Dim`` + instance, so ``torch.export`` treats those axes as one symbol. Conflicting + (min, max) ranges for the same name are rejected. 
+ """ + registry: dict[str, Any] = {} + bounds: dict[str, tuple[int, int]] = {} + for inp in _collect_inputs(arg_inputs) + _collect_inputs(kwarg_inputs): + shared_dims = getattr(inp, "shared_dims", None) + if not shared_dims or inp.shape_mode != Input._ShapeMode.DYNAMIC: + continue + assert isinstance(inp.shape, dict) + min_shape = inp.shape["min_shape"] + max_shape = inp.shape["max_shape"] + for axis, dim_name in shared_dims.items(): + lo, hi = int(min_shape[axis]), int(max_shape[axis]) + if dim_name in bounds: + if bounds[dim_name] != (lo, hi): + raise ValueError( + f"Dimension name '{dim_name}' is used with conflicting ranges " + f"{bounds[dim_name]} and {(lo, hi)}. A shared named dimension " + f"must have identical (min, max) on every input that uses it." + ) + else: + bounds[dim_name] = (lo, hi) + registry[dim_name] = Dim(dim_name, min=lo, max=hi) + return registry + + +def get_dynamic_shapes_kwargs( + inputs: Any, dim_registry: Optional[dict[str, Any]] = None +) -> Union[dict[str, Any], list[Any]]: if isinstance(inputs, dict): dynamic_shapes_kwarg = {} for k, v in inputs.items(): - dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v) + dynamic_shapes_kwarg[k] = get_dynamic_shapes_kwargs(v, dim_registry) return dynamic_shapes_kwarg elif isinstance(inputs, Input): - return get_dynamic_shapes(inputs) + return get_dynamic_shapes(inputs, dim_registry) elif isinstance(inputs, (list, tuple)): dynamic_shapes = [] for input in inputs: - dynamic_shapes.append(get_dynamic_shapes(input)) + dynamic_shapes.append(get_dynamic_shapes(input, dim_registry)) return dynamic_shapes raise TypeError(f"Unknown type {type(inputs)}.") -def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]: +def get_dynamic_shapes_args( + mod: torch.nn.Module, inputs: Any, dim_registry: Optional[dict[str, Any]] = None +) -> dict[str, Any]: # dynamic_shape is a dict and cannot work without keys. 
Here we use position argument name # in forward function as the name args = list(signature(mod.forward).parameters.keys()) dynamic_shapes = {} for input, input_name in zip(inputs, args[: len(inputs)]): - dynamic_shapes[input_name] = get_dynamic_shapes(input) + dynamic_shapes[input_name] = get_dynamic_shapes(input, dim_registry) return dynamic_shapes -def get_dynamic_shapes(input: Input) -> dict[Any, Any]: +def get_dynamic_shapes( + input: Input, dim_registry: Optional[dict[str, Any]] = None +) -> dict[Any, Any]: if not isinstance(input, Input): # If the input is torch.Tensor, no dynamic is needed. Return empty dict return {} else: dynamic_dims = {} if input.shape_mode == Input._ShapeMode.DYNAMIC: + assert isinstance(input.shape, dict) min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] max_shape = input.shape["max_shape"] + shared_dims = getattr(input, "shared_dims", None) or {} assert len(min_shape) == len(opt_shape) == len(max_shape) for dim in range(len(min_shape)): if min_shape[dim] == opt_shape[dim] == max_shape[dim]: continue + elif dim_registry is not None and dim in shared_dims: + # Named axis: reuse the shared Dim so axes with the same + # name across inputs become a single exported symbol. + dynamic_dims[dim] = dim_registry[shared_dims[dim]] else: dynamic_dims[dim] = Dim( input.name + "_" + str(dim), diff --git a/tests/py/dynamo/models/test_shared_dynamic_dim.py b/tests/py/dynamo/models/test_shared_dynamic_dim.py new file mode 100644 index 0000000000..814987332d --- /dev/null +++ b/tests/py/dynamo/models/test_shared_dynamic_dim.py @@ -0,0 +1,358 @@ +# type: ignore +""" +Tests for sharing a dynamic dimension across inputs via ``Input(shared_dims=...)``. + +Background: when a model takes multiple inputs whose dynamic axes must be +**equal at runtime** (e.g. HF encoders with ``input_ids`` / ``attention_mask`` +both shaped ``[B, S]``), naming each axis independently makes ``torch.export`` +mint an *independent* ``Dim`` per input. 
``torch.export`` then fails its +constraint check for any forward() that broadcasts across those axes (here: +``embed(input_ids) * mask.unsqueeze(-1)``), raising ``ConstraintViolationError``. + +``Input(shared_dims={axis: name})`` lets the caller tag a dynamic axis with a +name; the same name across inputs is exported as a single shared ``Dim``. All +the dynamic-shape intent lives on the ``Input`` objects -- no separate +``dynamic_shapes`` argument and no ``torch.export`` knowledge required at the +call site. +""" + +import unittest + +import pytest +import torch +import torch.nn as nn +import torch_tensorrt as torchtrt +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +class _SharedBatchEncoder(nn.Module): + """HF-style encoder stand-in: two int64 inputs sharing the batch axis. + + The ``embed(input_ids) * mask.unsqueeze(-1)`` broadcast forces + ``input_ids.size(0) == attention_mask.size(0)`` -- the relationship a shared + named dimension expresses. 
+ """ + + def __init__(self, vocab: int = 1024, hidden: int = 32): + super().__init__() + self.embed = nn.Embedding(vocab, hidden) + self.proj = nn.Linear(hidden, hidden) + + def forward(self, input_ids, attention_mask): + x = self.embed(input_ids) + mask = attention_mask.unsqueeze(-1).to(x.dtype) + return self.proj(x * mask) + + +def _named_input(name: str, seq: int = 16, batch_min: int = 1, batch_max: int = 4): + """A dynamic int64 Input whose batch axis (0) is named "B" for sharing.""" + return torchtrt.Input( + min_shape=(batch_min, seq), + opt_shape=(batch_max, seq), + max_shape=(batch_max, seq), + dtype=torch.int64, + name=name, + shared_dims={0: "B"}, + ) + + +@pytest.mark.unit +@pytest.mark.critical +def test_shared_dims_shared_batch_kwarg_inputs(): + """Shared batch axis declared via ``Input(shared_dims={0: "B"})`` on both + kwarg inputs -- same name => one exported symbol; engine matches eager.""" + model = _SharedBatchEncoder().eval().cuda() + + kwarg_inputs = { + "input_ids": _named_input("input_ids"), + "attention_mask": _named_input("attention_mask"), + } + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + kwarg_inputs=kwarg_inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + # Sample at the optimization shape and at a smaller batch within the range. 
+ for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(input_ids=ids, attention_mask=mask) + out = trt_mod(input_ids=ids, attention_mask=mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"shared_dims shared batch (kwargs) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_shared_dims_shared_batch_positional_inputs(): + """Same feature with positional ``inputs=[...]`` instead of kwargs.""" + model = _SharedBatchEncoder().eval().cuda() + + positional_inputs = [ + _named_input("input_ids"), + _named_input("attention_mask"), + ] + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=positional_inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") + + with torch.no_grad(): + ref = model(ids, mask) + out = trt_mod(ids, mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"shared_dims shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_shared_dims_shared_batch_mixed_args_and_kwargs(): + """input_ids passed positionally, attention_mask as a kwarg; both share "B".""" + model = _SharedBatchEncoder().eval().cuda() + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=[_named_input("input_ids")], + kwarg_inputs={"attention_mask": _named_input("attention_mask")}, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") + + with 
torch.no_grad(): + ref = model(ids, attention_mask=mask) + out = trt_mod(ids, attention_mask=mask) + + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"shared_dims shared batch (mixed) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_shared_dims_conflicting_ranges_raises(): + """Same name with different (min, max) across inputs is a user error.""" + from torch_tensorrt.dynamo._tracer import build_dim_registry + + seq = 16 + inputs = { + "input_ids": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(4, seq), + max_shape=(4, seq), + dtype=torch.int64, + name="input_ids", + shared_dims={0: "B"}, + ), + "attention_mask": torchtrt.Input( + min_shape=(1, seq), + opt_shape=(8, seq), + max_shape=(8, seq), + dtype=torch.int64, + name="attention_mask", + shared_dims={0: "B"}, + ), + } + with assertions.assertRaises(ValueError): + build_dim_registry((), inputs) + + +@pytest.mark.unit +def test_shared_dims_rejected_on_static_axis(): + """Naming a static axis (min == max) is rejected at Input construction.""" + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=(1, 16), + opt_shape=(1, 16), + max_shape=(1, 16), + dtype=torch.int64, + name="x", + shared_dims={0: "B"}, + ) + + +@pytest.mark.unit +def test_shared_dims_rejected_on_out_of_range_axis(): + """An axis index outside the input's rank is rejected at construction.""" + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=(1, 16), + opt_shape=(4, 16), + max_shape=(4, 16), + dtype=torch.int64, + name="x", + shared_dims={5: "B"}, # rank is 2; axis 5 does not exist + ) + + +@pytest.mark.unit +def test_default_path_unchanged_for_static_inputs(): + """Sanity check: a fully static input with no shared_dims is unchanged.""" + + class StaticModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8) + + def forward(self, x): + return self.linear(x) + + model = 
StaticModel().eval().cuda() + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=[torchtrt.Input(shape=(2, 8), dtype=torch.float32, name="x")], + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + x = torch.randn((2, 8), device="cuda") + with torch.no_grad(): + ref = model(x) + out = trt_mod(x) + assertions.assertTrue(cosine_similarity(ref, out) > COSINE_THRESHOLD) + + +# --------------------------------------------------------------------------- +# Namedtuple shape API tests +# --------------------------------------------------------------------------- +# +# The namedtuple API lets users name axes by using a namedtuple as the shape +# spec. Field names that appear on multiple inputs are automatically treated +# as a shared torch.export.Dim — no explicit shared_dims kwarg required. +# +# input_shape1 = namedtuple('S', ['n', 'c', 'h', 'w']) +# input_shape2 = namedtuple('S', ['c', 'seq']) +# Both have field 'c' → one shared Dim("c"). + +from collections import namedtuple + + +@pytest.mark.unit +@pytest.mark.critical +def test_namedtuple_shared_batch_positional_inputs(): + """Namedtuple field 'c' shared across two inputs — same as shared_dims={...:'c'}.""" + model = _SharedBatchEncoder().eval().cuda() + + # seq is static + S1 = namedtuple("shape", ["c", "seq"]) + + positional_inputs = [ + torchtrt.Input( + min_shape=S1(1, 16), + opt_shape=S1(4, 16), + max_shape=S1(4, 16), + dtype=torch.int64, + name="input_ids", + ), + torchtrt.Input( + min_shape=S1(1, 16), + opt_shape=S1(4, 16), + max_shape=S1(4, 16), + dtype=torch.int64, + name="attention_mask", + ), + ] + + trt_mod = torchtrt.compile( + model, + ir="dynamo", + inputs=positional_inputs, + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + + for bs in (4, 2): + ids = torch.randint(0, 1024, (bs, 16), dtype=torch.int64, device="cuda") + mask = torch.ones((bs, 16), dtype=torch.int64, device="cuda") + with torch.no_grad(): + ref = model(ids, mask) 
+ out = trt_mod(ids, mask) + cos_sim = cosine_similarity(ref, out) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"namedtuple shared batch (positional) out-of-tolerance at bs={bs}: cos_sim={cos_sim}", + ) + + +@pytest.mark.unit +def test_namedtuple_static_axes_skipped(): + """Static axes (min==max) in a namedtuple are not added to shared_dims.""" + S = namedtuple("shape", ["batch", "seq"]) + inp = torchtrt.Input( + min_shape=S(1, 16), + opt_shape=S(4, 16), + max_shape=S(4, 16), + dtype=torch.int64, + name="x", + ) + # 'seq' axis: min=max=16 → static, must not appear in shared_dims + assertions.assertNotIn(1, inp.shared_dims) + # 'batch' axis: min=1, max=4 → dynamic, must appear + assertions.assertIn(0, inp.shared_dims) + assertions.assertEqual(inp.shared_dims[0], "batch") + + +@pytest.mark.unit +def test_namedtuple_mismatched_fields_raises(): + """opt_shape namedtuple with different fields than min_shape is rejected.""" + S1 = namedtuple("shape", ["b", "c"]) + S2 = namedtuple("shape", ["b", "seq"]) + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=S1(1, 16), + opt_shape=S2(4, 16), + max_shape=S1(4, 16), + dtype=torch.int64, + name="x", + ) + + +@pytest.mark.unit +def test_namedtuple_and_shared_dims_together_raises(): + """Passing both a namedtuple shape and shared_dims kwarg is rejected.""" + S = namedtuple("shape", ["b", "c"]) + with assertions.assertRaises(ValueError): + torchtrt.Input( + min_shape=S(1, 16), + opt_shape=S(4, 16), + max_shape=S(4, 16), + dtype=torch.int64, + name="x", + shared_dims={0: "b"}, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])