test/test_transforms_v2.py (86 additions, 17 deletions)
@@ -1240,6 +1240,10 @@ def test_kernel_video(self):
            make_image_tensor,
            make_image_pil,
            make_image,
+            pytest.param(
+                functools.partial(make_image_cvcuda, batch_dims=(1,)),
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
            make_bounding_boxes,
            make_segmentation_mask,
            make_video,
@@ -1255,6 +1259,11 @@ def test_functional(self, make_input):
            (F.horizontal_flip_image, torch.Tensor),
            (F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
            (F.horizontal_flip_image, tv_tensors.Image),
+            pytest.param(
+                F._geometry._horizontal_flip_image_cvcuda,
+                cvcuda.Tensor,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
            (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.horizontal_flip_mask, tv_tensors.Mask),
            (F.horizontal_flip_video, tv_tensors.Video),

[Comment from Member, on the added cvcuda.Tensor line]
Oh, oops, that's a hard dependency and it crashes on the jobs where CVCUDA isn't available. Let's hack something for now, we'll think of a better way to handle that later:

Suggested change:
-                cvcuda.Tensor,
+                None,

Then in the code below:

def test_functional_signature(self, kernel, input_type):
    if kernel is F._geometry._horizontal_flip_image_cvcuda:
        input_type = _import_cvcuda().Tensor

[Reply from @justincdavis, Dec 3, 2025]
I have been using the string "cvcuda.Tensor" and then checking input_type == "cvcuda.Tensor". If we want a better-engineered solution than None or strings, we could make a cvcuda_tensor type which is None if CV-CUDA isn't installed.

Maybe something like:

cvcuda_tensor = None
if CVCUDA_AVAILABLE:
    cvcuda_tensor = _import_cvcuda().Tensor

This at least keeps the naming consistent and drops the if statements in the signature tests.
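For illustration, a minimal sketch of how that sentinel could slot into the parametrization (hypothetical code; it assumes the CVCUDA_AVAILABLE flag and _import_cvcuda helper this PR already uses in the test file):

# Sketch: module-level sentinel that is a real type only when CV-CUDA is installed.
cvcuda_tensor = _import_cvcuda().Tensor if CVCUDA_AVAILABLE else None

# A parametrize entry can then name the sentinel directly; the skipif mark
# ensures the case never runs (so input_type is never None) on jobs without CV-CUDA.
pytest.param(
    F._geometry._horizontal_flip_image_cvcuda,
    cvcuda_tensor,
    marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),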
@@ -1270,6 +1279,10 @@ def test_functional_signature(self, kernel, input_type):
            make_image_tensor,
            make_image_pil,
            make_image,
+            pytest.param(
+                functools.partial(make_image_cvcuda, batch_dims=(1,)),
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
            make_bounding_boxes,
            make_segmentation_mask,
            make_video,
@@ -1283,13 +1296,32 @@ def test_transform(self, make_input, device):
    @pytest.mark.parametrize(
        "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
    )
-    def test_image_correctness(self, fn):
-        image = make_image(dtype=torch.uint8, device="cpu")
-
-        actual = fn(image)
-        expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))
-
-        torch.testing.assert_close(actual, expected)
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                functools.partial(make_image_cvcuda, batch_dims=(1,)),
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
+        ],
+    )
+    def test_image_correctness(self, fn, make_input):
+        image = make_input()
+        actual = fn(image)
+        if isinstance(image, cvcuda.Tensor):
+            # For CVCUDA input: compare against the tensor kernel's result.
+            expected = F.horizontal_flip(F.cvcuda_to_tensor(image))
+            assert_equal(F.cvcuda_to_tensor(actual), expected)
+        else:
+            # For PIL/regular image input.
+            expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))
+            assert_equal(actual, expected)

    def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
        affine_matrix = np.array(
Expand Down Expand Up @@ -1345,6 +1377,10 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
functools.partial(make_image_cvcuda, batch_dims=(1,)),
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ ... @@
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform_noop(self, make_input, device):
        input = make_input(device=device)

        transform = transforms.RandomHorizontalFlip(p=0)

        output = transform(input)
-
-        assert_equal(output, input)
+        if isinstance(input, cvcuda.Tensor):
+            assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input))
+        else:
+            assert_equal(output, input)

class TestAffine:
Expand Down Expand Up @@ -1856,6 +1893,10 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
functools.partial(make_image_cvcuda, batch_dims=(1,)),
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1871,6 +1912,11 @@ def test_functional(self, make_input):
            (F.vertical_flip_image, torch.Tensor),
            (F._geometry._vertical_flip_image_pil, PIL.Image.Image),
            (F.vertical_flip_image, tv_tensors.Image),
+            pytest.param(
+                F._geometry._vertical_flip_image_cvcuda,
+                cvcuda.Tensor,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
            (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
            (F.vertical_flip_mask, tv_tensors.Mask),
            (F.vertical_flip_video, tv_tensors.Video),
@@ -1886,6 +1932,10 @@ def test_functional_signature(self, kernel, input_type):
            make_image_tensor,
            make_image_pil,
            make_image,
+            pytest.param(
+                functools.partial(make_image_cvcuda, batch_dims=(1,)),
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
            make_bounding_boxes,
            make_segmentation_mask,
            make_video,
@@ -1897,13 +1947,28 @@ def test_transform(self, make_input, device):
        check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))

    @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
-    def test_image_correctness(self, fn):
-        image = make_image(dtype=torch.uint8, device="cpu")
-
-        actual = fn(image)
-        expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))
-
-        torch.testing.assert_close(actual, expected)
+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                functools.partial(make_image_cvcuda, batch_dims=(1,)),
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
+        ],
+    )
+    def test_image_correctness(self, fn, make_input):
+        image = make_input()
+        actual = fn(image)
+        if isinstance(image, cvcuda.Tensor):
+            # For CVCUDA input: compare against the tensor kernel's result.
+            expected = F.vertical_flip(F.cvcuda_to_tensor(image))
+            assert_equal(F.cvcuda_to_tensor(actual), expected)
+        else:
+            # For PIL/regular image input.
+            expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))
+            assert_equal(actual, expected)

    def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
        affine_matrix = np.array(
@@ -1955,6 +2020,10 @@ def test_keypoints_correctness(self, fn):
            make_image_tensor,
            make_image_pil,
            make_image,
+            pytest.param(
+                functools.partial(make_image_cvcuda, batch_dims=(1,)),
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
+            ),
            make_bounding_boxes,
            make_segmentation_mask,
            make_video,
@@ ... @@
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_transform_noop(self, make_input, device):
        input = make_input(device=device)

        transform = transforms.RandomVerticalFlip(p=0)

        output = transform(input)
-
-        assert_equal(output, input)
+        if isinstance(input, cvcuda.Tensor):
+            assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input))
+        else:
+            assert_equal(output, input)

class TestRotate:
torchvision/transforms/v2/_geometry.py (11 additions, 1 deletion)
@@ -11,7 +11,7 @@
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
-from torchvision.transforms.v2.functional._utils import _FillType
+from torchvision.transforms.v2.functional._utils import _FillType, _import_cvcuda, _is_cvcuda_available

from ._transform import _RandomApplyTransform
from ._utils import (
@@ -30,6 +30,9 @@
    query_size,
)

+CVCUDA_AVAILABLE = _is_cvcuda_available()
+if CVCUDA_AVAILABLE:
+    cvcuda = _import_cvcuda()


class RandomHorizontalFlip(_RandomApplyTransform):
    """Horizontally flip the input with a given probability.

[Comment from Member, on the added cvcuda = _import_cvcuda() line]
Is this import needed? I feel we are not using the cvcuda module directly anywhere in this file?
@@ -45,6 +48,9 @@ class RandomHorizontalFlip(_RandomApplyTransform):

    _v1_transform_cls = _transforms.RandomHorizontalFlip

+    if CVCUDA_AVAILABLE:
+        _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor)
+
    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        return self._call_kernel(F.horizontal_flip, inpt)
@@ -63,6 +69,10 @@ class RandomVerticalFlip(_RandomApplyTransform):

    _v1_transform_cls = _transforms.RandomVerticalFlip

+    if CVCUDA_AVAILABLE:
+        _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor)
+
+
    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        return self._call_kernel(F.vertical_flip, inpt)
torchvision/transforms/v2/functional/_geometry.py (22 additions, 2 deletions)
@@ -2,7 +2,7 @@
import numbers
import warnings
from collections.abc import Sequence
-from typing import Any, Optional, Union
+from typing import Any, Optional, TYPE_CHECKING, Union

import PIL.Image
import torch
@@ -26,7 +26,13 @@

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

-from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
+from ._utils import _FillTypeJIT, _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_five_ten_crop_kernel_internal, _register_kernel_internal
+
+CVCUDA_AVAILABLE = _is_cvcuda_available()
+if TYPE_CHECKING:
+    import cvcuda
+if CVCUDA_AVAILABLE:
+    cvcuda = _import_cvcuda()

def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -61,6 +67,12 @@ def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.hflip(image)

+def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=1)
+
+
+if CVCUDA_AVAILABLE:
+    _horizontal_flip_image_cvcuda_registered = _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda)

@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:

[Comment on the added _horizontal_flip_image_cvcuda function]
Maybe a bit of a nitpick, but could we rename the function to _horizontal_flip_cvcuda? CV-CUDA only operates on one datatype, so the extra "image" in the funcname does not add value IMO. Removing it also mirrors the cvcuda_to_tensor and tensor_to_cvcuda functions.

[Reply from @NicolasHug (Member), Dec 3, 2025]
The cvcuda_to_tensor and tensor_to_cvcuda functions are a bit of an outlier in that sense, but most other kernels specify the nature of the input they work on. We have e.g.

• horizontal_flip_image for tensors and tv_tensors.Image
• _horizontal_flip_image_pil
• horizontal_flip_mask
• horizontal_flip_bounding_boxes
• etc.

The CVCUDA backend is basically of the same nature as the PIL backend, so it makes sense to keep it named _horizontal_flip_cvcuda (EDIT: meant _horizontal_flip_image_cvcuda!!) IMO, like we have _horizontal_flip_image_pil.

[Reply from Member]
@NicolasHug just to be sure: you are saying it makes sense to keep it named _horizontal_flip_cvcuda; I guess you mean it makes sense to keep it named _horizontal_flip_image_cvcuda?

[Reply from @NicolasHug]
Yes, thanks for catching! I'll edit above to avoid further confusion.
@@ -150,6 +162,14 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
    return _FP.vflip(image)


+def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=0)
+
+
+if CVCUDA_AVAILABLE:
+    _vertical_flip_image_cvcuda_registered = _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda)
+
+
@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
    return vertical_flip_image(mask)