diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 06d6514770d..e45bda4a393 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2520,14 +2520,29 @@ def test_errors(self): with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"): cls(lambda x: x) - with pytest.raises(ValueError, match="at least one transform"): - transforms.Compose([]) + for cls in ( + transforms.Compose, + transforms.RandomApply, + transforms.RandomChoice, + transforms.RandomOrder, + ): + + with pytest.raises(ValueError, match="at least one transform"): + cls([]) for p in [-1, 2]: with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")): transforms.RandomApply([lambda x: x], p=p) - for transforms_, p in [([lambda x: x], []), ([], [1.0])]: + for transforms_, p in [ + ([lambda x: x], []), + ( + [lambda x: x, lambda x: x], + [ + 1.0, + ], + ), + ]: with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"): transforms.RandomChoice(transforms_, p=p) diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index d63dd9f8f46..95ec25a22f8 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -87,6 +87,8 @@ def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: floa if not isinstance(transforms, (Sequence, nn.ModuleList)): raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`") + elif not transforms: + raise ValueError("Pass at least one transform") self.transforms = transforms if not (0.0 <= p <= 1.0): @@ -133,7 +135,8 @@ def __init__( ) -> None: if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") - + elif not transforms: + raise ValueError("Pass at least one transform") if p is None: p = [1] * len(transforms) elif len(p) != len(transforms): @@ -163,6 +166,8 @@ class RandomOrder(Transform): def __init__(self, transforms: Sequence[Callable]) -> None: if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") + elif not transforms: + raise ValueError("Pass at least one transform") super().__init__() self.transforms = transforms