I'm currently working with torch2trt in our project, and due to some constraints, we need to implement padding manually. I've managed to reproduce the error using the following dummy model:
```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch2trt import torch2trt

# Define BasicConv2d
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

# Define the backbone
class Backbone(nn.Module):
    def __init__(self):
        super(Backbone, self).__init__()
        self.layer = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.layer(x)

# Define the complete model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.backbone = Backbone()
        self.pad_width = 2
        self.head = BasicConv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.first_forward = True

    def forward(self, x):
        features = self.backbone(x)  # 4D feature map
        if self.first_forward:
            self.first_forward = False
            batch_size, channels, self.fm_height, self.fm_width = features.shape
            padded_height = self.fm_height + 2 * self.pad_width
            padded_width = self.fm_width + 2 * self.pad_width
            self.features_padded = torch.zeros((batch_size, channels, padded_height, padded_width),
                                               device=features.device)
        self.features_padded[:, :, self.pad_width:self.pad_width + self.fm_height, self.pad_width:self.pad_width + self.fm_width] = features
        output = self.head(self.features_padded)  # Process the padded feature map
        return output

# Instantiate and test the model
model = SimpleModel().eval().cuda()  # Set to evaluation mode

# Input tensor
input_tensor = torch.randn(1, 3, 64, 64).cuda()  # NCHW format: batch=1, channels=3, height=64, width=64

# First forward for torch.fx init calcs
output_tensor = model(input_tensor)

# Torch.fx symbolic tracing
try:
    model_fx = symbolic_trace(model)
    print("Torch.fx symbolic tracing successful.")
except Exception as e:
    print("Error during Torch.fx symbolic tracing:", e)

# Convert to TensorRT using torch2trt
try:
    model_trt = torch2trt(model_fx, [input_tensor])
    print("torch2trt conversion successful.")
except Exception as e:
    print("Error during torch2trt conversion:", e)
```
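For reference, the manual padding in `forward()` above (zero buffer plus an in-place slice write) produces the same values as zero padding with `F.pad`. The sketch below is a minimal check in eager PyTorch. It also shows an out-of-place variant that builds the padded map with `torch.cat` instead of writing into a buffer. `pad_slice_assign` and `pad_cat` are hypothetical helper names. Unlike the repro, they allocate the buffer on every call instead of caching it on `self`. Whether torch2trt 0.5.0 can convert `new_zeros`/`torch.cat` in this pattern has not been verified here.

```python
import torch
import torch.nn.functional as F

def pad_slice_assign(x: torch.Tensor, p: int) -> torch.Tensor:
    # Same idea as the repro: allocate zeros, then write x into the centre in place
    n, c, h, w = x.shape
    out = torch.zeros((n, c, h + 2 * p, w + 2 * p), dtype=x.dtype, device=x.device)
    out[:, :, p:p + h, p:p + w] = x
    return out

def pad_cat(x: torch.Tensor, p: int) -> torch.Tensor:
    # Hypothetical out-of-place alternative: concatenate zero strips along W, then H.
    # Assumption: converter support for new_zeros/cat in torch2trt is untested here.
    n, c, h, w = x.shape
    side = x.new_zeros((n, c, h, p))
    x = torch.cat([side, x, side], dim=3)        # (n, c, h, w + 2p)
    top = x.new_zeros((n, c, p, w + 2 * p))
    return torch.cat([top, x, top], dim=2)       # (n, c, h + 2p, w + 2p)

x = torch.randn(1, 64, 64, 64)                   # shape of the backbone output in the repro
ref = F.pad(x, (2, 2, 2, 2))                     # (left, right, top, bottom), zero fill
print(torch.equal(pad_slice_assign(x, 2), ref), torch.equal(pad_cat(x, 2), ref))  # True True
```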
The error occurs during the torch2trt conversion step. Passing the original model (instead of model_fx) to torch2trt produces the same errors.
Here are the errors we are getting:
```
[01/06/2025-15:26:23] [TRT] [E] 3: head.conv:1:CONVOLUTION:GPU: at least 4 dimensions are required for input.
[01/06/2025-15:26:23] [TRT] [E] 4: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2276] Error Code 4: Internal Error (head.conv:1:CONVOLUTION:GPU: output shape can not be computed)
[01/06/2025-15:26:23] [TRT] [E] 3: [network.cpp::addScaleNd::1151] Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addScaleNd::1151, condition: qdqScale || basicScale )
Error during torch2trt conversion: 'NoneType' object has no attribute 'get_output'
```
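TensorRT reports that `head.conv` received an input with fewer than 4 dimensions. For comparison, the shapes eager PyTorch produces for the 1x3x64x64 input can be worked out from the layer settings in the repro. The plain-Python sketch below assumes dilation 1 and uses the standard `nn.Conv2d` output-size formula. `conv2d_out` and `expected_shapes` are hypothetical helpers. The sketch only covers the PyTorch-side arithmetic and says nothing about how torch2trt builds the TensorRT tensor for `self.features_padded`.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # nn.Conv2d spatial output size with dilation 1
    return (size + 2 * padding - kernel) // stride + 1

def expected_shapes(n: int = 1, h: int = 64, w: int = 64, pad_width: int = 2):
    # Backbone: Conv2d(3, 64, kernel_size=3, stride=1, padding=1) keeps H and W
    bh, bw = conv2d_out(h, 3, 1, 1), conv2d_out(w, 3, 1, 1)
    backbone = (n, 64, bh, bw)
    # Manual padding adds pad_width on each side of H and W
    padded = (n, 64, bh + 2 * pad_width, bw + 2 * pad_width)
    # Head: Conv2d(64, 128, 3, stride=1, padding=1); BatchNorm2d does not change the shape
    head = (n, 128, conv2d_out(padded[2], 3, 1, 1), conv2d_out(padded[3], 3, 1, 1))
    return backbone, padded, head

for name, shape in zip(("backbone", "padded", "head"), expected_shapes()):
    print(name, shape)
# backbone (1, 64, 64, 64)
# padded (1, 64, 68, 68)
# head (1, 128, 68, 68)
```

So in eager mode the tensor fed to `head.conv` is the 4D tensor (1, 64, 68, 68).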
Environment:
- PyTorch version: 2.1.2+cu118
- torch2trt version: 0.5.0
- TensorRT version: 10.0.1
Any help would be greatly appreciated!
Thank you in advance!