diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..3b889e93d2e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,13 +20,15 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -275,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -287,6 +300,11 @@ def __init__( if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): actual, expected = (to_image(input) for input in [actual, expected]) + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -400,8 +418,8 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..63b17788114 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,9 +21,11 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, + cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, @@ -41,7 +43,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors @@ -5128,6 +5129,9 @@ def test_kernel_video(self): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), ], ) def test_functional(self, make_input): @@ -5143,9 +5147,16 @@ def test_functional(self, make_input): (F.perspective_mask, tv_tensors.Mask), (F.perspective_video, tv_tensors.Video), (F.perspective_keypoints, tv_tensors.KeyPoints), + pytest.param( + F._geometry._perspective_cvcuda, + "cvcuda.Tensor", + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"), + ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.perspective, kernel=kernel, input_type=input_type) @pytest.mark.parametrize("distortion_scale", [0.5, 0.0, 1.0]) @@ -5159,6 +5170,9 @@ def test_functional_signature(self, kernel, input_type): make_segmentation_mask, make_video, make_keypoints, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), ], ) def test_transform(self, distortion_scale, make_input): @@ -5174,12 +5188,25 @@ def test_transform_error(self, distortion_scale): "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] ) @pytest.mark.parametrize("fill", CORRECTNESS_FILLS) - def test_image_functional_correctness(self, coefficients, interpolation, fill): - image = make_image(dtype=torch.uint8, device="cpu") + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available") + ), + ], + ) + def test_image_functional_correctness(self, coefficients, interpolation, fill, make_input): + image = make_input(dtype=torch.uint8, device="cpu") actual = F.perspective( image, startpoints=None, endpoints=None, coefficients=coefficients, interpolation=interpolation, fill=fill ) + if make_input is make_image_cvcuda: + actual = cvcuda_to_pil_compatible_tensor(actual) + image = cvcuda_to_pil_compatible_tensor(image) + expected = F.to_image( F.perspective( F.to_pil_image(image), @@ -5191,13 +5218,20 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill): ) ) - if interpolation is transforms.InterpolationMode.BILINEAR: - abs_diff = (actual.float() - expected.float()).abs() - assert (abs_diff > 1).float().mean() < 7e-2 - mae = abs_diff.mean() - assert mae < 3 - else: - assert_equal(actual, expected) + if make_input is make_image: + if interpolation is transforms.InterpolationMode.BILINEAR: + abs_diff = (actual.float() - expected.float()).abs() + assert (abs_diff > 1).float().mean() < 7e-2 + mae = abs_diff.mean() + assert mae < 3 + else: + assert_equal(actual, expected) + else: # CV-CUDA + # just check that the shapes/dtypes are the same, cvcuda warp_perspective uses different algorithm + # visually the results are the same on real images, + # realistically, the diff is not visible to the human eye + tolerance = 255 if interpolation is transforms.InterpolationMode.NEAREST else 125 + assert_close(actual, expected, rtol=0, atol=tolerance) def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, endpoints): format = bounding_boxes.format diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 1418a6b4953..ba662744656 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -11,7 +11,7 @@ from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform -from torchvision.transforms.v2.functional._utils import _FillType +from torchvision.transforms.v2.functional._utils import _FillType, is_cvcuda_tensor from ._transform import _RandomApplyTransform from ._utils import ( @@ -936,6 +936,8 @@ class RandomPerspective(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomPerspective + _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,) + def __init__( self, distortion_scale: float = 0.5, diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a904d8d7cbd..7ce5bdc7b7e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,4 +1,5 @@ import io +from typing import TYPE_CHECKING import PIL.Image @@ -8,7 +9,15 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..5be9c62902a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..34a08fcb8f2 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,9 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union + +import numpy as np import PIL.Image import torch @@ -26,7 +28,22 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -2258,6 +2275,70 @@ def perspective_video( ) +if CVCUDA_AVAILABLE: + _cvcuda_interp = { + InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR, + "bilinear": cvcuda.Interp.LINEAR, + "linear": cvcuda.Interp.LINEAR, + 2: cvcuda.Interp.LINEAR, + InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC, + "bicubic": cvcuda.Interp.CUBIC, + 3: cvcuda.Interp.CUBIC, + InterpolationMode.NEAREST: cvcuda.Interp.NEAREST, + "nearest": cvcuda.Interp.NEAREST, + 0: cvcuda.Interp.NEAREST, + InterpolationMode.BOX: cvcuda.Interp.BOX, + "box": cvcuda.Interp.BOX, + 4: cvcuda.Interp.BOX, + InterpolationMode.HAMMING: cvcuda.Interp.HAMMING, + "hamming": cvcuda.Interp.HAMMING, + 5: cvcuda.Interp.HAMMING, + InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS, + "lanczos": cvcuda.Interp.LANCZOS, + 1: cvcuda.Interp.LANCZOS, + } + + +def _perspective_cvcuda( + image: "cvcuda.Tensor", + startpoints: Optional[list[list[int]]], + endpoints: Optional[list[list[int]]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + fill: _FillTypeJIT = None, + coefficients: Optional[list[float]] = None, +) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() + + c = _perspective_coefficients(startpoints, endpoints, coefficients) + interpolation = _check_interpolation(interpolation) + + interp = _cvcuda_interp.get(interpolation) + if interp is None: + raise ValueError(f"Invalid interpolation mode: {interpolation}") + + xform = np.array([[c[0], c[1], c[2]], [c[3], c[4], c[5]], [c[6], c[7], 1.0]], dtype=np.float32) + + num_channels = image.shape[-1] + if fill is None: + border_value = np.zeros(num_channels, dtype=np.float32) + elif isinstance(fill, (int, float)): + border_value = np.full(num_channels, fill, dtype=np.float32) + else: + border_value = np.array(fill, dtype=np.float32)[:num_channels] + + return cvcuda.warp_perspective( + image, + xform, + flags=interp | cvcuda.Interp.WARP_INVERSE_MAP, + border_mode=cvcuda.Border.CONSTANT, + border_value=border_value, + ) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(perspective, _import_cvcuda().Tensor)(_perspective_cvcuda) + + def elastic( inpt: torch.Tensor, displacement: torch.Tensor, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..ee562cb2aee 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] -def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]: """Get size of `cvcuda.Tensor` with NHWC layout.""" hw = list(image.shape[-3:-1]) ndims = len(hw) @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..0fa05a2113c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..44b2edeaf2d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -169,3 +169,20 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + """ + Check if the input is a CVCUDA tensor. + + Args: + inpt: The input to check. + + Returns: + True if the input is a CV-CUDA tensor, False otherwise. + """ + try: + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + except ImportError: + return False