Skip to content

Commit 4a4cec1

Browse files
committed
Addresses the case where the shape of the upsample tensor contains an ITensor
1 parent 863c869 commit 4a4cec1

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,41 @@ def get_shape_with_dynamic_shape(
123123
select_layer = ctx.net.add_select(condition_val, input_shape, scale_res)
124124
set_layer_name(select_layer, target, f"{name}_select")
125125
return select_layer.get_output(0)
126+
127+
128+
def to_trt_shape_tensor(
    ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor]
) -> TRTTensor | List[int]:
    """
    Convert a mixed shape list (ints + ITensors) into a single ITensor.

    Args:
        ctx: ConversionContext holding the TensorRT network being built.
        target: fx node target (used for layer naming).
        name (str): base name for layer naming.
        shape_list (list[int | ITensor]): list containing static ints and/or
            ITensors (dynamic dimensions).

    Returns:
        A single concatenated ITensor if ``shape_list`` contains any ITensor,
        otherwise the original plain Python list of ints, unchanged.
    """
    # Fast path: all dimensions are static ints, so no TensorRT layers are
    # needed at all. Returning early avoids adding dead constant layers to
    # the network (the previous implementation built one add_constant per
    # int and then discarded them on this path).
    if all(isinstance(dim, int) for dim in shape_list):
        return shape_list

    trt_tensors = []
    for i, dim in enumerate(shape_list):
        if isinstance(dim, int):
            # Wrap each static dim as a 1-element int32 constant so it can
            # be concatenated with the dynamic ITensor dimensions.
            const = ctx.net.add_constant((1,), np.array([dim], dtype=np.int32))
            set_layer_name(const, target, f"{name}_dim{i}_const")
            trt_tensors.append(const.get_output(0))
        else:
            # Assume it's already an ITensor (a dynamic dimension).
            trt_tensors.append(dim)

    # Concatenate all 1-D pieces along axis 0 into one shape tensor.
    concat_layer = ctx.net.add_concatenation(trt_tensors)
    concat_layer.axis = 0
    set_layer_name(concat_layer, target, f"{name}_shape_concat")
    return concat_layer.get_output(0)

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
has_dynamic_shape,
1010
set_layer_name,
1111
)
12-
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
12+
from torch_tensorrt.dynamo.conversion.impl.shape import (
13+
get_shape_with_dynamic_shape,
14+
to_trt_shape_tensor,
15+
)
1316

1417

1518
def upsample(
@@ -28,14 +31,17 @@ def upsample(
2831
if scale_factor is not None:
2932
layer.scales = [1.0, 1.0] + list(scale_factor)
3033
else:
31-
shape = list(input.shape)[:2] + list(size)
34+
shape = list(input.shape)[:2]
35+
if size is not None:
36+
shape += list(size)
3237
if has_dynamic_shape(shape):
3338
shape = get_shape_with_dynamic_shape(
3439
ctx, target, source_ir, name, shape, input
3540
)
3641
layer.set_input(1, shape)
3742
else:
38-
layer.shape = shape
43+
layer.shape = to_trt_shape_tensor(ctx, target, name, shape)
44+
layer.set_input(1, layer.shape)
3945

4046
if mode == "nearest":
4147
layer.resize_mode = trt.InterpolationMode.NEAREST

0 commit comments

Comments
 (0)