36 changes: 30 additions & 6 deletions test/common_utils.py
@@ -20,7 +20,8 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
from torchvision.utils import _Image_fromarray


@@ -284,8 +285,27 @@ def __init__(
mae=False,
**other_parameters,
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = (to_image(input) for input in [actual, expected])
# Convert PIL images to tv_tensors.Image (regardless of what the other input is)
if isinstance(actual, PIL.Image.Image):
actual = to_image(actual)
if isinstance(expected, PIL.Image.Image):
expected = to_image(expected)

if _is_cvcuda_available():
cvcuda = _import_cvcuda()

if isinstance(actual, cvcuda.Tensor):
actual = cvcuda_to_tensor(actual)
# Remove batch dimension if it's 1 for easier comparison
if actual.shape[0] == 1:
actual = actual[0]

This seems unnecessary; we should be able to compare tensors where the batch dim is 1. Try to remove it, and if it doesn't work for any reason let me know.

EDIT: ah, OK, it's for when we compare a 3D PIL image to a 4D cvcuda tensor. That's... fine. Let's explain why then (addition in bold):

Remove batch dimension if it's 1 for easier comparison **against 3D PIL images**

Are we able to split the logic to drop the batch dim and move to CPU into a helper function? We do the same thing in the test itself, and I think it would improve clarity to have an explicit helper.

Line in reference

It's a 2-line helper, I'd say let's leave it out for now. We might need it later, but it'll be more obvious to me when I get to see more code and more usage. There might be some refactoring opportunities that you're seeing @justincdavis and that I'm not yet seeing - sorry for that, I'm sure things will be easier to decide for me once I've reviewed more of the coming PRs.
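
For reference, a minimal sketch of the helper being discussed, assuming it would live in test/common_utils.py (the name _cvcuda_to_cpu_tensor is hypothetical):

def _cvcuda_to_cpu_tensor(inpt):
    # Convert a cvcuda.Tensor to a torch.Tensor, drop a leading batch dim of 1
    # (so 4D CV-CUDA tensors can be compared against 3D PIL/tv_tensors images),
    # and move the result to CPU.
    out = cvcuda_to_tensor(inpt)
    if out.shape[0] == 1:
        out = out[0]
    return out.cpu()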

actual = actual.cpu()
if isinstance(expected, cvcuda.Tensor):
expected = cvcuda_to_tensor(expected)
# Remove batch dimension if it's 1 for easier comparison
if expected.shape[0] == 1:
expected = expected[0]
expected = expected.cpu()

super().__init__(actual, expected, **other_parameters)
self.mae = mae
@@ -400,8 +420,8 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
@@ -541,5 +561,9 @@ def ignore_jit_no_profile_information_warning():
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
warnings.filterwarnings(
"ignore",
message=re.escape("operator() profile_node %"),
category=UserWarning,
)
yield
84 changes: 66 additions & 18 deletions test/test_transforms_v2.py
@@ -1240,6 +1240,10 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1255,6 +1259,11 @@ def test_functional(self, make_input):
(F.horizontal_flip_image, torch.Tensor),
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._horizontal_flip_image_cvcuda,
cvcuda.Tensor,

Oh, oops, that's a hard dependency and it crashes on the jobs where CVCUDA isn't available. Let's hack something for now; we'll think of a better way to handle it later:

Suggested change:
-    cvcuda.Tensor,
+    None,

Then in the code below:

def test_functional_signature(self, kernel, input_type):
    if kernel is F._geometry._horizontal_flip_image_cvcuda:
        input_type = _import_cvcuda().Tensor

justincdavis (Dec 3, 2025):

I have been using the string "cvcuda.Tensor" and then checking input_type == "cvcuda.Tensor". If we want some kind of better engineered solution as opposed to None or strings, we could make a cvcuda_tensor type which is None if CV-CUDA isn't installed.

Maybe something like:

cvcuda_tensor = None
if CVCUDA_AVAILABLE:
    cvcuda_tensor = _import_cvcuda().Tensor

This at least keeps the naming consistent and drops the if statements in the signature tests.
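
A minimal sketch of that proposal, for illustration (the module-level name CVCUDA_TENSOR is hypothetical):

# Resolved once when the test module is imported; stays None when CV-CUDA is not
# installed, so the parametrize list can reference it without a hard cvcuda import.
CVCUDA_TENSOR = _import_cvcuda().Tensor if CVCUDA_AVAILABLE else None

The parametrize entries would then pass CVCUDA_TENSOR instead of cvcuda.Tensor, and the existing skipif mark still keeps the CV-CUDA case from running when the library is missing.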

marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
@@ -1270,6 +1279,10 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1283,13 +1296,23 @@ def test_transform(self, make_input, device):
@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu")

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
],
)
def test_image_correctness(self, fn, make_input):
image = make_input()
actual = fn(image)
expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()
expected = F.horizontal_flip(F.to_pil_image(image))
assert_equal(actual, expected)

def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
affine_matrix = np.array(
@@ -1345,6 +1368,10 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1354,11 +1381,8 @@ def test_keypoints_correctness(self, fn):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, make_input, device):
input = make_input(device=device)

transform = transforms.RandomHorizontalFlip(p=0)

output = transform(input)

assert_equal(output, input)


@@ -1856,6 +1880,10 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1871,6 +1899,11 @@ def test_functional(self, make_input):
(F.vertical_flip_image, torch.Tensor),
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._vertical_flip_image_cvcuda,
cvcuda.Tensor,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
@@ -1886,6 +1919,10 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1897,13 +1934,23 @@ def test_transform(self, make_input, device):
check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))

@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu")

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
],
)
def test_image_correctness(self, fn, make_input):
image = make_input()
actual = fn(image)
expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()
expected = F.vertical_flip(F.to_pil_image(image))
assert_equal(actual, expected)

def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
affine_matrix = np.array(
@@ -1955,6 +2002,10 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1964,11 +2015,8 @@ def test_keypoints_correctness(self, fn):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, make_input, device):
input = make_input(device=device)

transform = transforms.RandomVerticalFlip(p=0)

output = transform(input)

assert_equal(output, input)


17 changes: 16 additions & 1 deletion torchvision/transforms/v2/_geometry.py
@@ -11,7 +11,12 @@
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._utils import _FillType
from torchvision.transforms.v2.functional._utils import (
_FillType,
_import_cvcuda,
_is_cvcuda_available,
is_cvcuda_tensor,
)

from ._transform import _RandomApplyTransform
from ._utils import (
@@ -30,6 +35,10 @@
query_size,
)

CVCUDA_AVAILABLE = _is_cvcuda_available()
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda()

Is this import needed? I feel we are not using the cvcuda module directly anywhere in this file.



class RandomHorizontalFlip(_RandomApplyTransform):
"""Horizontally flip the input with a given probability.
@@ -45,6 +54,9 @@ class RandomHorizontalFlip(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomHorizontalFlip

if CVCUDA_AVAILABLE:
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return self._call_kernel(F.horizontal_flip, inpt)

@@ -63,6 +75,9 @@ class RandomVerticalFlip(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomVerticalFlip

if CVCUDA_AVAILABLE:
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return self._call_kernel(F.vertical_flip, inpt)

31 changes: 29 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -2,7 +2,7 @@
import numbers
import warnings
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import PIL.Image
import torch
@@ -26,7 +26,18 @@

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
from ._utils import (
_FillTypeJIT,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
_register_five_ten_crop_kernel_internal,
_register_kernel_internal,
)

CVCUDA_AVAILABLE = _is_cvcuda_available()
if TYPE_CHECKING:
import cvcuda


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -62,6 +73,14 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.hflip(image)


def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":

Maybe a bit of a nitpick, but could we rename the function to _horizontal_flip_cvcuda? CV-CUDA only operates on one datatype, so the extra "image" in the funcname does not add value IMO. Removing it also mirrors the cvcuda_to_tensor and tensor_to_cvcuda functions.

NicolasHug (Member, Dec 3, 2025):

The cvcuda_to_tensor and tensor_to_cvcuda functions are a bit of an outlier in that sense, but most other kernels specify the nature of the input they work on. We have e.g.

  • horizontal_flip_image for tensors and tv_tensors.Image
  • _horizontal_flip_image_pil
  • horizontal_flip_mask
  • horizontal_flip_bounding_boxes
  • etc.

The CVCUDA backend is basically of the same nature as the PIL backend, so it makes sense to keep it named _horizontal_flip_cvcuda (EDIT: meant _horizontal_flip_image_cvcuda!!) IMO, like we have _horizontal_flip_image_pil.

@NicolasHug just to be sure: you are saying it makes sense to keep it named _horizontal_flip_cvcuda, but I guess you mean it makes sense to keep it named _horizontal_flip_image_cvcuda?

Yes, thanks for catching! I'll edit above to avoid further confusion.

return _import_cvcuda().flip(image, flipCode=1)


if CVCUDA_AVAILABLE:
_register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda)


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image(mask)
@@ -150,6 +169,14 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.vflip(image)


def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
return _import_cvcuda().flip(image, flipCode=0)


if CVCUDA_AVAILABLE:
_register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image(mask)
8 changes: 8 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
@@ -169,3 +169,11 @@ def _is_cvcuda_available():
return True
except ImportError:
return False


def is_cvcuda_tensor(inpt: Any) -> bool:
try:
cvcuda = _import_cvcuda()
return isinstance(inpt, cvcuda.Tensor)
except ImportError:
return False