diff --git a/tests/test_cnn/test_wrappers.py b/tests/test_cnn/test_wrappers.py index 94c51ce15c..8c76ccbdd4 100644 --- a/tests/test_cnn/test_wrappers.py +++ b/tests/test_cnn/test_wrappers.py @@ -4,6 +4,8 @@ import pytest import torch import torch.nn as nn +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, Linear, MaxPool2d, MaxPool3d) @@ -376,24 +378,19 @@ def test_nn_op_forward_called(): nn_module_forward.assert_called_with(x_normal) -@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10)) +@pytest.mark.skipif( + digit_version(TORCH_VERSION) < digit_version('1.10'), + reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9') def test_fx_compatibility(): - try: - from torch import fx - - # ensure the fx trace can pass the network - for Net in (MaxPool2d, MaxPool3d): - net = Net(1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - for Net in (Linear, ): - net = Net(1, 1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d): - net = Net(1, 1, 1) - gm_module = fx.symbolic_trace(net) - print(gm_module.code) - except ImportError: - # torch.fx might not be available - pass + from torch import fx + + # ensure the fx trace can pass the network + for Net in (MaxPool2d, MaxPool3d): + net = Net(1) + gm_module = fx.symbolic_trace(net) # noqa: F841 + for Net in (Linear, ): + net = Net(1, 1) + gm_module = fx.symbolic_trace(net) # noqa: F841 + for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d): + net = Net(1, 1, 1) + gm_module = fx.symbolic_trace(net) # noqa: F841