Skip to content

Commit 975b403

Browse files
committed
adding test case and correcting a case
1 parent 4a4cec1 commit 975b403

File tree

3 files changed

+59
-14
lines changed

3 files changed

+59
-14
lines changed

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def to_trt_shape_tensor(
132132
Convert a mixed shape list (ints + ITensors) into a single ITensor.
133133
134134
Args:
135-
ctx: ConversionContext
136-
target: fx node target (used for naming).
135+
ctx (ConversionContext): TensorRT ConversionContext object.
136+
target (Target): Target of fx node.
137137
name (str): base name for layer naming.
138138
shape_list (list[int | ITensor]): list containing static ints and/or ITensors.
139139
@@ -148,16 +148,14 @@ def to_trt_shape_tensor(
148148
set_layer_name(const, target, f"{name}_dim{i}_const")
149149
trt_tensors.append(const.get_output(0))
150150
else:
151-
# Assume it's already an ITensor
152151
trt_tensors.append(s)
153152

154-
if trt_tensors:
155-
if any(not isinstance(s, int) for s in shape_list):
156-
# Concatenate everything into a single ITensor
157-
concat_layer = ctx.net.add_concatenation(trt_tensors)
158-
concat_layer.axis = 0
159-
set_layer_name(concat_layer, target, f"{name}_shape_concat")
160-
return concat_layer.get_output(0)
153+
if any(not isinstance(s, int) for s in shape_list):
154+
# Concatenate everything into a single ITensor if there are any ITensors/Tensors
155+
concat_layer = ctx.net.add_concatenation(trt_tensors)
156+
concat_layer.axis = 0
157+
set_layer_name(concat_layer, target, f"{name}_shape_concat")
158+
return concat_layer.get_output(0)
161159

162160
# If no ITensor found, return plain list of ints
163161
return shape_list

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,19 @@ def upsample(
3232
layer.scales = [1.0, 1.0] + list(scale_factor)
3333
else:
3434
shape = list(input.shape)[:2]
35-
if size is not None:
36-
shape += list(size)
35+
if size is not None:
36+
shape += list(size)
3737
if has_dynamic_shape(shape):
3838
shape = get_shape_with_dynamic_shape(
3939
ctx, target, source_ir, name, shape, input
4040
)
4141
layer.set_input(1, shape)
4242
else:
43-
layer.shape = to_trt_shape_tensor(ctx, target, name, shape)
44-
layer.set_input(1, layer.shape)
43+
trt_shape = to_trt_shape_tensor(ctx, target, name, shape)
44+
if isinstance(trt_shape, list):
45+
layer.shape = trt_shape
46+
else:
47+
layer.set_input(1, trt_shape)
4548

4649
if mode == "nearest":
4750
layer.resize_mode = trt.InterpolationMode.NEAREST

tests/py/dynamo/conversion/test_upsample_aten.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,50 @@ def forward(self, x):
296296
]
297297
self.run_test_with_dynamic_shape(TestModule(), input_specs)
298298

299+
@parameterized.expand(
300+
[
301+
([torch.tensor(3), 3], None),
302+
(None, [torch.tensor(0.5), 1.5]),
303+
]
304+
)
305+
def test_nearest2d_mixed_dynamic_shape(self, output_size, scale_factors):
306+
class TestModule(torch.nn.Module):
307+
def forward(self, x):
308+
out_size = output_size
309+
scale = scale_factors
310+
311+
return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
312+
313+
input_specs = [
314+
Input(
315+
min_shape=(1, 1, 1, 1),
316+
opt_shape=(5, 5, 5, 5),
317+
max_shape=(9, 9, 9, 9),
318+
dtype=torch.float32,
319+
)
320+
]
321+
self.run_test_with_dynamic_shape(TestModule(), input_specs)
322+
323+
@parameterized.expand(
324+
[
325+
# Mix of Tensor and int in output_size
326+
([torch.tensor(3), 3], None),
327+
# Mix of Tensor and float in scale_factors
328+
(None, [torch.tensor(0.5), 1.5]),
329+
]
330+
)
331+
def test_nearest2d_mixed_static_input(self, output_size, scale_factors):
332+
class TestModule(torch.nn.Module):
333+
def forward(self, x):
334+
out_size = output_size
335+
scale = scale_factors
336+
return torch.ops.aten.upsample_nearest2d.vec(x, out_size, scale)
337+
338+
input_size = [7, 7] # H, W
339+
inputs = [torch.randn([1, 1] + input_size)] # shape [1, 1, 7, 7]
340+
341+
self.run_test(TestModule(), inputs)
342+
299343

300344
if __name__ == "__main__":
301345
run_tests()

0 commit comments

Comments
 (0)