diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index a757cf023e..53f7fb0ad6 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2550,15 +2550,7 @@ def aten_ops_cdist_forward( def avg_pool_param_validator(pool_node: Node) -> bool: - ceil_mode = args_bounds_check(pool_node.args, 4, False) divisor_override = args_bounds_check(pool_node.args, 6) - - if ceil_mode is not False: - _LOGGER.debug( - f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}." - ) - return False - if divisor_override is not None: _LOGGER.debug( f"Currently we don't support divisor_override, got divisor_override={divisor_override}." @@ -2694,17 +2686,14 @@ def topk_sort_validator(k: int) -> bool: def max_pool_param_validator(pool_node: Node) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) - ceil_mode = args_bounds_check(pool_node.args, 5, False) - if dilation != 1: - _LOGGER.debug(f"Currently we don't support dilation, got dilation={dilation}.") - return False + if not isinstance(dilation, (list, tuple)): + dilation = (dilation,) - if ceil_mode is not False: - _LOGGER.debug( - f"Currently we don't support specifying ceil_mode, got ceil_mode={ceil_mode}." - ) - return False + for dil in dilation: + if dil != 1: + _LOGGER.debug("Currently we don't support dilation > 1 at any dimension.") + return False return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/pool.py b/py/torch_tensorrt/dynamo/conversion/impl/pool.py index 4e18aaaef2..bc70d59527 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/pool.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/pool.py @@ -30,8 +30,9 @@ def avg_poolNd( count_include_pad: bool = True, divisor_override: Optional[int] = None, ) -> TRTTensor: - if ceil_mode is not False: - raise RuntimeError("ceil_mode is not yet supported!") + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN + if ceil_mode: + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP if divisor_override is not None: raise RuntimeError("divisor_override is not yet supported!") @@ -57,6 +58,7 @@ def avg_poolNd( pool_layer.stride_nd = stride pool_layer.padding_nd = padding pool_layer.average_count_excludes_padding = not count_include_pad + pool_layer.padding_mode = padding_mode set_layer_name(pool_layer, target, name, source_ir) return pool_layer.get_output(0) @@ -77,11 +79,9 @@ def max_poolNd( if has_dynamic_shape(input.shape): assert input.shape[1] != -1, "Channel dim can't be dynamic for pooling." - if dilation != 1: - raise RuntimeError("dilation is not yet supported!") - - if ceil_mode is not False: - raise RuntimeError("ceil_mode is not yet supported!") + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_DOWN + if ceil_mode: + padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP dim = len(kernel_size) @@ -103,6 +103,7 @@ def max_poolNd( pool_layer.stride_nd = stride pool_layer.padding_nd = padding + pool_layer.padding_mode = padding_mode set_layer_name(pool_layer, target, name, source_ir) return pool_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_pool_aten.py b/tests/py/dynamo/conversion/test_pool_aten.py index 29fdf30480..38746f23b3 100644 --- a/tests/py/dynamo/conversion/test_pool_aten.py +++ b/tests/py/dynamo/conversion/test_pool_aten.py @@ -15,6 +15,8 @@ class TestPoolConverter(DispatchTestCase): ((4,), (1,), (1,)), ((5,), (2,), (0,)), ((7,), (2,), (1,)), + ((3,), (1,), (1,), 0, True), + ((7,), (2,), (1,), 0, True), ] ) def test_avg_pool1d( @@ -44,8 +46,11 @@ def forward(self, x): (3, 1, 1), ((2, 2), [], (1, 0)), ((4, 3), (1, 1), (1, 1)), + ((4, 3), (1, 1), (1, 1), True), ((5, 4), (2, 1), (1, 0)), + ((5, 4), (2, 1), (1, 0), True), ((7, 7), (1, 2), (0, 1)), + ((7, 7), (1, 2), (0, 1), True), ] ) def test_avg_pool2d( @@ -70,7 +75,7 @@ def forward(self, x): ) inputs = [torch.randn(1, 3, 32, 32)] - self.run_test(TestModule(), inputs, use_dynamo_tracer=True) + self.run_test(TestModule(), inputs, rtol=5e-03, atol=5e-03, use_dynamo_tracer=True) @parameterized.expand( [ @@ -80,6 +85,8 @@ def forward(self, x): ((4, 3, 2), (1, 1, 1), (1, 1, 0)), ((5, 4, 3), (2, 1, 2), (1, 0, 1)), ((7, 7, 7), (1, 2, 1), (0, 1, 1)), + ((7, 7, 7), (1, 2, 1), (0, 1, 1), True), + ((5, 4, 3), (2, 1, 2), (1, 0, 1), True), ] ) def test_avg_pool3d( @@ -168,6 +175,16 @@ def forward(self, x): (1, 1), (1, 1), ), + ( + (1, 1, 1, 1), + (2, 2, 2, 2), + (3, 3, 3, 3), + torch.float, + (3, 3), + (1, 1), + (1, 1), + True + ), ] ) def test_dynamic_shape_pool2d( @@ -258,6 +275,7 @@ def forward(self, x): ((4,), (1,), (1,)), ((5,), (2,), (0,)), ((7,), (2,), (1,)), + ((7,), (2,), (1,), 1, True), ] ) def test_max_pool1d( @@ -290,6 +308,9 @@ def forward(self, x): ((4, 3), (1, 1), (1, 1)), ((5, 4), (2, 1), (1, 0)), ((7, 7), (1, 2), (0, 1)), + ((4, 3), (1, 1), (1, 1), 1, True), + ((5, 4), (2, 1), (1, 0), 1, True), + ((7, 7), (1, 2), (0, 1), 1, True), ] ) def test_max_pool2d( @@ -322,6 +343,7 @@ def forward(self, x): ((4, 3, 2), (1, 1, 1), (1, 1, 0)), ((5, 4, 3), (2, 1, 2), (1, 0, 1)), ((7, 7, 7), (1, 2, 1), (0, 1, 1)), + ((7, 7, 7), (1, 2, 1), (0, 1, 1), 1, True), ] ) def test_max_pool3d(