Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
skps
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed May 5, 2023
1 parent f61eaed commit 03c3a5f
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
marks=[
pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed"),
pytest.mark.skipif(not _ICEDATA_AVAILABLE, reason="icedata package isn't installed"),
pytest.xfail() # ToDo
pytest.xfail(), # ToDo
],
),
pytest.param(
Expand Down
2 changes: 2 additions & 0 deletions tests/image/semantic_segm/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_FIFTYONE_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
_PIL_AVAILABLE,
_SEGMENTATION_MODELS_AVAILABLE,
_TOPIC_IMAGE_AVAILABLE,
)
from flash.image import SemanticSegmentation, SemanticSegmentationData
Expand Down Expand Up @@ -375,6 +376,7 @@ def test_from_fiftyone(tmpdir):

@staticmethod
@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_map_labels(tmpdir):
tmp_dir = Path(tmpdir)

Expand Down
5 changes: 2 additions & 3 deletions tests/image/semantic_segm/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest
import torch

from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.image.segmentation import SemanticSegmentation
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
Expand All @@ -37,8 +37,7 @@ def test_semantic_segmentation_heads_registry(head):
assert res.shape[1] == 10


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@unittest.mock.patch("flash.image.segmentation.heads.smp")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_pretrained_weights(mock_smp):
mock_smp.create_model = unittest.mock.MagicMock()
available_weights = SemanticSegmentation.available_pretrained_weights("resnet18")
Expand Down
17 changes: 9 additions & 8 deletions tests/image/semantic_segm/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@

from flash import Trainer
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TOPIC_IMAGE_AVAILABLE, _TOPIC_SERVE_AVAILABLE
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.image import SemanticSegmentation
from flash.image.segmentation.data import SemanticSegmentationData
from tests.helpers.task_tester import TaskTester


@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
class TestSemanticSegmentation(TaskTester):
task = SemanticSegmentation
task_args = (2,)
Expand Down Expand Up @@ -59,29 +60,29 @@ def example_test_sample(self):
return self.example_train_sample


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_non_existent_backbone():
with pytest.raises(KeyError):
SemanticSegmentation(2, "i am never going to implement this lol")


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_freeze():
model = SemanticSegmentation(2)
model.freeze()
for p in model.backbone.parameters():
assert p.requires_grad is False


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_unfreeze():
model = SemanticSegmentation(2)
model.unfreeze()
for p in model.backbone.parameters():
assert p.requires_grad is True


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_predict_tensor():
img = torch.rand(1, 3, 64, 64)
model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
Expand All @@ -93,7 +94,7 @@ def test_predict_tensor():
assert len(out[0][0][0]) == 64


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_predict_numpy():
img = np.ones((1, 3, 64, 64))
model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
Expand All @@ -105,14 +106,14 @@ def test_predict_numpy():
assert len(out[0][0][0]) == 64


@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
@mock.patch("flash._IS_TESTING", True)
def test_serve():
model = SemanticSegmentation(2)
model.eval()
model.serve()


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
def test_available_pretrained_weights():
assert SemanticSegmentation.available_pretrained_weights("resnet18") == ["imagenet", "ssl", "swsl"]
3 changes: 2 additions & 1 deletion tests/image/semantic_segm/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import torch

from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _SEGMENTATION_MODELS_AVAILABLE, _TOPIC_IMAGE_AVAILABLE
from flash.image.segmentation.output import FiftyOneSegmentationLabelsOutput, SegmentationLabelsOutput


@pytest.mark.skipif(not _TOPIC_IMAGE_AVAILABLE, "image libraries aren't installed.")
@pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")
class TestSemanticSegmentationLabelsOutput:
@staticmethod
def test_smoke():
Expand Down

0 comments on commit 03c3a5f

Please sign in to comment.