Skip to content

Commit

Permalink
modify skip test conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao committed Aug 3, 2023
1 parent 616fc72 commit fc883ec
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions tests/test_cnn/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pytest
import torch
import torch.nn as nn
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
Linear, MaxPool2d, MaxPool3d)
Expand Down Expand Up @@ -376,24 +378,19 @@ def test_nn_op_forward_called():
nn_module_forward.assert_called_with(x_normal)


@patch('mmcv.cnn.bricks.wrappers.TORCH_VERSION', (1, 10))
@pytest.mark.skipif(
digit_version(TORCH_VERSION) < digit_version('1.10'),
reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9')
def test_fx_compatibility():
try:
from torch import fx

# ensure the fx trace can pass the network
for Net in (MaxPool2d, MaxPool3d):
net = Net(1)
gm_module = fx.symbolic_trace(net)
print(gm_module.code)
for Net in (Linear, ):
net = Net(1, 1)
gm_module = fx.symbolic_trace(net)
print(gm_module.code)
for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d):
net = Net(1, 1, 1)
gm_module = fx.symbolic_trace(net)
print(gm_module.code)
except ImportError:
# torch.fx might not be available
pass
from torch import fx

# ensure the fx trace can pass the network
for Net in (MaxPool2d, MaxPool3d):
net = Net(1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Linear, ):
net = Net(1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d):
net = Net(1, 1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841

0 comments on commit fc883ec

Please sign in to comment.