24 changes: 21 additions & 3 deletions test/common_utils.py
@@ -20,13 +20,15 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
from torchvision.utils import _Image_fromarray


IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CVCUDA_AVAILABLE = _is_cvcuda_available()
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
    return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
    tensor = cvcuda_to_tensor(tensor)
    if tensor.ndim != 4:
        raise ValueError(f"CV-CUDA Tensor should be 4-dimensional. Got {tensor.ndim} dimensions.")
    if tensor.shape[0] != 1:
        raise ValueError(
            f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
        )
    return tensor.squeeze(0).cpu()


class ImagePair(TensorLikePair):
    def __init__(
        self,
@@ -287,6 +300,11 @@ def __init__(
        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
            actual, expected = (to_image(input) for input in [actual, expected])

        # Handle CV-CUDA Tensors
        if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
            # Use a PIL-compatible tensor so we can always compare with PIL.Image.Image
            actual = cvcuda_to_pil_compatible_tensor(actual)

        super().__init__(actual, expected, **other_parameters)
        self.mae = mae

@@ -400,8 +418,8 @@ def make_image_pil(*args, **kwargs):
    return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
    return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
    return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
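For orientation, here is a minimal sketch of how the two new test helpers compose, assuming CV-CUDA and a CUDA device are available (the `assert` is illustrative only):

```python
import torch

from common_utils import cvcuda_to_pil_compatible_tensor, make_image_cvcuda

# Build a test image and convert it to a cvcuda.Tensor; batch_dims=(1,)
# keeps a singleton batch so the result can be mapped back to one image.
cv_img = make_image_cvcuda(dtype=torch.uint8, device="cuda")

# Round-trip back to a CPU torch.Tensor with the batch dimension squeezed,
# ready for comparison against a PIL-derived image tensor.
plain = cvcuda_to_pil_compatible_tensor(cv_img)
assert plain.ndim == 3 and plain.device.type == "cpu"
```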
54 changes: 44 additions & 10 deletions test/test_transforms_v2.py
@@ -21,9 +21,11 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
    assert_close,
    assert_equal,
    cache,
    cpu_and_cuda,
    cvcuda_to_pil_compatible_tensor,
    freeze_rng_state,
    ignore_jit_no_profile_information_warning,
    make_bounding_boxes,
@@ -41,7 +43,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
@@ -5128,6 +5129,9 @@ def test_kernel_video(self):
            make_segmentation_mask,
            make_video,
            make_keypoints,
            pytest.param(
                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
            ),
        ],
    )
    def test_functional(self, make_input):
@@ -5143,9 +5147,16 @@ def test_functional(self, make_input):
            (F.perspective_mask, tv_tensors.Mask),
            (F.perspective_video, tv_tensors.Video),
            (F.perspective_keypoints, tv_tensors.KeyPoints),
            pytest.param(
                F._geometry._perspective_cvcuda,
                "cvcuda.Tensor",
                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
            ),
        ],
    )
    def test_functional_signature(self, kernel, input_type):
        if input_type == "cvcuda.Tensor":
            input_type = _import_cvcuda().Tensor
        check_functional_kernel_signature_match(F.perspective, kernel=kernel, input_type=input_type)

    @pytest.mark.parametrize("distortion_scale", [0.5, 0.0, 1.0])
@@ -5159,6 +5170,9 @@ def test_functional_signature(self, kernel, input_type):
            make_segmentation_mask,
            make_video,
            make_keypoints,
            pytest.param(
                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
            ),
        ],
    )
    def test_transform(self, distortion_scale, make_input):
@@ -5174,12 +5188,25 @@ def test_transform_error(self, distortion_scale):
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
)
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
def test_image_functional_correctness(self, coefficients, interpolation, fill):
image = make_image(dtype=torch.uint8, device="cpu")
    @pytest.mark.parametrize(
        "make_input",
        [
            make_image,
            pytest.param(
                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
            ),
        ],
    )
    def test_image_functional_correctness(self, coefficients, interpolation, fill, make_input):
        image = make_input(dtype=torch.uint8, device="cpu")

        actual = F.perspective(
            image, startpoints=None, endpoints=None, coefficients=coefficients, interpolation=interpolation, fill=fill
        )
        if make_input is make_image_cvcuda:
            actual = cvcuda_to_pil_compatible_tensor(actual)
            image = cvcuda_to_pil_compatible_tensor(image)

        expected = F.to_image(
            F.perspective(
                F.to_pil_image(image),
@@ -5191,13 +5218,20 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill):
            )
        )

        if interpolation is transforms.InterpolationMode.BILINEAR:
            abs_diff = (actual.float() - expected.float()).abs()
            assert (abs_diff > 1).float().mean() < 7e-2
            mae = abs_diff.mean()
            assert mae < 3
        else:
            assert_equal(actual, expected)
        if make_input is make_image:
            if interpolation is transforms.InterpolationMode.BILINEAR:
                abs_diff = (actual.float() - expected.float()).abs()
                assert (abs_diff > 1).float().mean() < 7e-2
                mae = abs_diff.mean()
                assert mae < 3
            else:
                assert_equal(actual, expected)
        else:  # CV-CUDA
            # cvcuda.warp_perspective uses a different resampling algorithm, so exact equality is not
            # expected; with atol=255, the NEAREST case effectively only checks shapes and dtypes.
            # Visually the results match on real images, and the difference is imperceptible.
            tolerance = 255 if interpolation is transforms.InterpolationMode.NEAREST else 125
            assert_close(actual, expected, rtol=0, atol=tolerance)

    def _reference_perspective_bounding_boxes(self, bounding_boxes, *, startpoints, endpoints):
        format = bounding_boxes.format
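A note on the tolerances above: on uint8 images, `atol=255` accepts the maximal possible per-pixel difference, so the NEAREST branch effectively validates only shape, dtype, and device, while `atol=125` does constrain values. A toy illustration with `torch.testing.assert_close` (the `assert_close` imported from `common_utils` builds on the same comparison machinery):

```python
import torch
from torch.testing import assert_close

a = torch.zeros(3, 4, 4, dtype=torch.uint8)
b = torch.full((3, 4, 4), 255, dtype=torch.uint8)

# atol=255 tolerates the maximal uint8 difference, so this passes;
# only a shape, dtype, or device mismatch would make it fail.
assert_close(a, b, rtol=0, atol=255)

# atol=125 actually constrains pixel values: this would raise an AssertionError.
# assert_close(a, b, rtol=0, atol=125)
```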
4 changes: 3 additions & 1 deletion torchvision/transforms/v2/_geometry.py
@@ -11,7 +11,7 @@
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._utils import _FillType
from torchvision.transforms.v2.functional._utils import _FillType, is_cvcuda_tensor

from ._transform import _RandomApplyTransform
from ._utils import (
@@ -936,6 +936,8 @@ class RandomPerspective(_RandomApplyTransform):

    _v1_transform_cls = _transforms.RandomPerspective

    _transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor,)

    def __init__(
        self,
        distortion_scale: float = 0.5,
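With the `is_cvcuda_tensor` predicate added to `_transformed_types`, `RandomPerspective` dispatches on CV-CUDA tensors instead of passing them through untouched. A usage sketch, assuming CV-CUDA and a CUDA device are available and that `to_cvcuda_tensor` accepts a batched NCHW tensor (as `make_image_cvcuda` above suggests):

```python
import torch
import torchvision.transforms.v2 as transforms
from torchvision.transforms.v2.functional import to_cvcuda_tensor

# A batched uint8 image moved into CV-CUDA's tensor format.
image = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8, device="cuda")
cv_image = to_cvcuda_tensor(image)

# p=1.0 forces the random transform to apply; the output stays a cvcuda.Tensor.
out = transforms.RandomPerspective(p=1.0)(cv_image)
```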
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
@@ -15,7 +15,7 @@
from torchvision._utils import sequence_to_str

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


@@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
    chws = {
        tuple(get_dimensions(inpt))
        for inpt in flat_inputs
        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
    }
    if not chws:
        raise TypeError("No image or video was found in the sample")
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
                tv_tensors.Mask,
                tv_tensors.BoundingBoxes,
                tv_tensors.KeyPoints,
                is_cvcuda_tensor,
            ),
        )
    }
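Both `query_chw` and `query_size` rely on `check_type` treating tuple entries that are callables (like `is_pure_tensor` and `is_cvcuda_tensor`) as predicates rather than classes. A minimal sketch of that mechanism, matching how it is used here (the real helper lives in `torchvision.transforms.v2._utils`):

```python
from typing import Any, Callable, Union

def check_type(obj: Any, types_or_checks: tuple[Union[type, Callable[[Any], bool]], ...]) -> bool:
    for type_or_check in types_or_checks:
        # Plain classes are matched with isinstance; callables are invoked
        # as predicates, which is how is_cvcuda_tensor participates here.
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False
```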
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
@@ -1,6 +1,6 @@
from torchvision.transforms import InterpolationMode # usort: skip

from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip

from ._meta import (
    clamp_bounding_boxes,
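`is_cvcuda_tensor` is re-exported here so downstream modules can use it as a dispatch predicate. Its implementation is not shown in this diff; a plausible minimal version, written purely for illustration, would be:

```python
from typing import Any

def is_cvcuda_tensor(inpt: Any) -> bool:
    # Return False rather than raising when CV-CUDA is not installed,
    # so the predicate is always safe to call during dispatch.
    try:
        import cvcuda  # type: ignore[import-not-found]
    except ImportError:
        return False
    return isinstance(inpt, cvcuda.Tensor)
```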
11 changes: 10 additions & 1 deletion torchvision/transforms/v2/functional/_augment.py
@@ -1,4 +1,5 @@
import io
from typing import TYPE_CHECKING

import PIL.Image

@@ -8,7 +9,15 @@
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once

from ._utils import _get_kernel, _register_kernel_internal
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal


CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    import cvcuda  # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
    cvcuda = _import_cvcuda()  # noqa: F811


def erase(
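The import block added to `_augment.py` (and mirrored in `_color.py` and `_geometry.py` below) uses a standard pattern: static type checkers always see `import cvcuda` under `TYPE_CHECKING`, while at runtime the name is bound only when the library is installed. Distilled, with stand-ins for torchvision's private helpers:

```python
from typing import TYPE_CHECKING

def _import_cvcuda():
    # Stand-in for torchvision's lazy-import helper; raises ImportError if absent.
    import cvcuda  # type: ignore[import-not-found]
    return cvcuda

def _is_cvcuda_available() -> bool:
    try:
        _import_cvcuda()
    except ImportError:
        return False
    return True

CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    # Only evaluated by static analyzers, never at runtime.
    import cvcuda  # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
    # Runtime binding; noqa: F811 silences the apparent redefinition.
    cvcuda = _import_cvcuda()  # noqa: F811
```

String annotations like `"cvcuda.Tensor"` in the kernels keep these modules importable even when CV-CUDA is absent.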
12 changes: 11 additions & 1 deletion torchvision/transforms/v2/functional/_color.py
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

import PIL.Image
import torch
from torch.nn.functional import conv2d
@@ -9,7 +11,15 @@

from ._misc import _num_value_bits, to_dtype_image
from ._type_conversion import pil_to_tensor, to_pil_image
from ._utils import _get_kernel, _register_kernel_internal
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal


CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    import cvcuda  # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
    cvcuda = _import_cvcuda()  # noqa: F811


def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
85 changes: 83 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -2,7 +2,9 @@
import numbers
import warnings
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import numpy as np

import PIL.Image
import torch
@@ -26,7 +28,22 @@

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
from ._utils import (
    _FillTypeJIT,
    _get_kernel,
    _import_cvcuda,
    _is_cvcuda_available,
    _register_five_ten_crop_kernel_internal,
    _register_kernel_internal,
)


CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
    import cvcuda  # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
    cvcuda = _import_cvcuda()  # noqa: F811


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -2258,6 +2275,70 @@ def perspective_video(
    )


if CVCUDA_AVAILABLE:
    _cvcuda_interp = {
        InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
        "bilinear": cvcuda.Interp.LINEAR,
        "linear": cvcuda.Interp.LINEAR,
        2: cvcuda.Interp.LINEAR,
        InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
        "bicubic": cvcuda.Interp.CUBIC,
        3: cvcuda.Interp.CUBIC,
        InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
        "nearest": cvcuda.Interp.NEAREST,
        0: cvcuda.Interp.NEAREST,
        InterpolationMode.BOX: cvcuda.Interp.BOX,
        "box": cvcuda.Interp.BOX,
        4: cvcuda.Interp.BOX,
        InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
        "hamming": cvcuda.Interp.HAMMING,
        5: cvcuda.Interp.HAMMING,
        InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
        "lanczos": cvcuda.Interp.LANCZOS,
        1: cvcuda.Interp.LANCZOS,
    }


def _perspective_cvcuda(
    image: "cvcuda.Tensor",
    startpoints: Optional[list[list[int]]],
    endpoints: Optional[list[list[int]]],
    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    fill: _FillTypeJIT = None,
    coefficients: Optional[list[float]] = None,
) -> "cvcuda.Tensor":
    cvcuda = _import_cvcuda()

    c = _perspective_coefficients(startpoints, endpoints, coefficients)
    interpolation = _check_interpolation(interpolation)

    interp = _cvcuda_interp.get(interpolation)
    if interp is None:
        raise ValueError(f"Invalid interpolation mode: {interpolation}")

    xform = np.array([[c[0], c[1], c[2]], [c[3], c[4], c[5]], [c[6], c[7], 1.0]], dtype=np.float32)

    num_channels = image.shape[-1]
    if fill is None:
        border_value = np.zeros(num_channels, dtype=np.float32)
    elif isinstance(fill, (int, float)):
        border_value = np.full(num_channels, fill, dtype=np.float32)
    else:
        border_value = np.array(fill, dtype=np.float32)[:num_channels]

    return cvcuda.warp_perspective(
        image,
        xform,
        flags=interp | cvcuda.Interp.WARP_INVERSE_MAP,
        border_mode=cvcuda.Border.CONSTANT,
        border_value=border_value,
    )


if CVCUDA_AVAILABLE:
    _register_kernel_internal(perspective, _import_cvcuda().Tensor)(_perspective_cvcuda)


def elastic(
    inpt: torch.Tensor,
    displacement: torch.Tensor,
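On the math in `_perspective_cvcuda`: the eight coefficients define a homography, and the matrix is applied in inverse-map mode, i.e. it sends each output pixel to the source location that gets sampled; this is why `cvcuda.Interp.WARP_INVERSE_MAP` is OR-ed into `flags`, matching the output-to-input convention torchvision's existing perspective kernels use for the same coefficients. A small sketch with illustrative coefficient values:

```python
import numpy as np

# Illustrative coefficients (a, b, ..., h); in the kernel they come from
# _perspective_coefficients(startpoints, endpoints, coefficients).
c = [1.0, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]

xform = np.array([[c[0], c[1], c[2]], [c[3], c[4], c[5]], [c[6], c[7], 1.0]], dtype=np.float32)

# In inverse-map mode the matrix maps an output pixel (x, y) to the source
# pixel that is sampled:
#   x_src = (a*x + b*y + c) / (g*x + h*y + 1)
#   y_src = (d*x + e*y + f) / (g*x + h*y + 1)
x, y = 10.0, 20.0
sx, sy, w = xform @ np.array([x, y, 1.0], dtype=np.float32)
print(sx / w, sy / w)  # source location sampled for output pixel (10, 20)
```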