Skip to content

🐛 [Bug] __torch__.torch.classes.tensorrt.Engine (of Python compilation unit at: 0) does not have a field with name 'set_output_tensors_as_unowned' #4000

@zewenli98

Description

@zewenli98

Bug Description

PR #3946 introduced the error below when using the C++ runtime. The Python runtime is not affected.

Traceback (most recent call last):
  File "/home/zewenl/Documents/pytorch/TensorRT/issues/issue_3981.py", line 59, in <module>
    out = trt_module(x)
          ^^^^^^^^^^^^^
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 414, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 845, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2196, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2171, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/anaconda3/envs/py312/lib/python3.12/site-packages/torch/__init__.py", line 2437, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py", line 47, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py", line 96, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py", line 166, in _pretraced_backend
    trt_compiled = compile_module(
                   ^^^^^^^^^^^^^^^
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 1091, in compile_module
    getattr(partitioned_module, target).set_output_tensors_as_unowned(True)
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/_features.py", line 106, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 366, in set_output_tensors_as_unowned
    self.engine.set_output_tensors_as_unowned(enabled)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='tensorrt' raised:
AttributeError: __torch__.torch.classes.tensorrt.Engine (of Python compilation unit at: 0) does not have a field with name 'set_output_tensors_as_unowned'

To Reproduce

import torch
import torch.nn as nn
import torch_tensorrt
import logging
# Emit DEBUG-level log records so library logging is visible while reproducing.
logging.basicConfig(level=logging.DEBUG)

# Fix the RNG seed so the random parameters and input below are repeatable.
torch.manual_seed(0)

class ExpandReshapeModel(nn.Module):
    """Minimal module used to reproduce the reported failure.

    Prepends a learnable class token to the input along dim 1, applies one
    linear projection from ``embed_dim`` to ``3 * embed_dim``, and splits the
    projected last dimension into ``(3, 12, -1)``.
    """

    def __init__(self, embed_dim: int=768) -> None:
        """Create the class token and the projection layer.

        Args:
            embed_dim: Size of the last input dimension; also the input
                feature size of the linear layer.
        """
        super().__init__()
        # Learnable token of shape (1, 1, embed_dim), randomly initialised.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        # Single projection producing 3 * embed_dim features.
        # NOTE(review): the name suggests fused query/key/value — confirm.
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend the class token, project, and unflatten the last dim.

        Args:
            x: Input tensor. The repro script passes shape (4, 196, 768);
                the code indexes dims 0 and 1 and projects the last dim.

        Returns:
            The projected tensor with its last dimension split into
            ``(3, 12, -1)``. For the repro input and ``embed_dim=768`` this
            is shape (4, 197, 3, 12, 64).
        """
        # Broadcast the (1, 1, embed_dim) token to the shape of x[:, :1], i.e.
        # one token per entry along dim 0.
        cls_token = self.cls_token.expand_as(x[:, :1])
        # Prepend the token along dim 1, growing that dim by one.
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        # Split the last dim (3 * embed_dim) into (3, 12, inferred).
        # NOTE(review): presumably (qkv, heads, head_dim) — confirm.
        reshaped_qkv = torch.unflatten(x, dim=-1, sizes=(3, 12, -1))
        return reshaped_qkv


# Build the model on the GPU in eval mode.
model = ExpandReshapeModel(embed_dim=768).cuda().eval()

# Random CUDA input of shape (4, 196, 768).
x = torch.randn(4, 196, 768).cuda()
# Mark dim 0 of x as dynamic, bounded to the range [2, 32].
torch._dynamo.mark_dynamic(x, index=0, min=2, max=32)
# Compile with the "tensorrt" backend. "use_python_runtime": False selects the
# non-Python runtime path, which is the one this report says is affected.
trt_module = torch.compile(model, backend="tensorrt", options={"min_block_size": 1, "use_python_runtime": False})

# The first call triggers backend compilation; the reported traceback is
# raised from this line.
out = trt_module(x)
print(out)

Metadata

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions