We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 975b403 commit 8db4c74Copy full SHA for 8db4c74
py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -143,7 +143,7 @@ def to_trt_shape_tensor(
143
trt_tensors = []
144
145
for i, s in enumerate(shape_list):
146
- if isinstance(s, int):
+ if isinstance(s, (int, torch.Tensor)):
147
const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
148
set_layer_name(const, target, f"{name}_dim{i}_const")
149
trt_tensors.append(const.get_output(0))
0 commit comments