Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
44db71c
implement additional cvcuda infra for all branches to avoid duplicate…
justincdavis Nov 25, 2025
e3dd700
update make_image_cvcuda to have default batch dim
justincdavis Nov 25, 2025
c035df1
add standardized setup to main for easier updating of PRs and branches
justincdavis Dec 2, 2025
98d7dfb
update is_cvcuda_tensor
justincdavis Dec 2, 2025
ddc116d
add cvcuda to pil compatible to transforms by default
justincdavis Dec 2, 2025
e51dc7e
remove cvcuda from transform class
justincdavis Dec 2, 2025
e14e210
merge with main
justincdavis Dec 4, 2025
4939355
resolve more formatting naming
justincdavis Dec 4, 2025
1e864d8
initial cvcuda normalize kernel implementation
justincdavis Nov 17, 2025
01efae7
add comment explaining mean/std behavior, one-line intermediate creation
justincdavis Nov 17, 2025
79ea0da
fix: normalize_cvcuda move to correct patterns for tests/exporting
justincdavis Nov 20, 2025
429f77f
fix tests crashing before run without cvcuda
justincdavis Nov 24, 2025
8ed3b26
resolve more review comments
justincdavis Nov 24, 2025
57ca083
remove extra parameterize for dtype
justincdavis Nov 24, 2025
184e379
simplify normalize testing into single test parameterize on input cre…
justincdavis Nov 26, 2025
995834a
update normalize based on PR reviews
justincdavis Dec 2, 2025
7105358
update normalize with changes from main
justincdavis Dec 4, 2025
0f8910e
remove extra cvcuda_available add
justincdavis Dec 4, 2025
969dd3f
check input type on kernel for signature test
justincdavis Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -5552,24 +5552,34 @@ def test_kernel_image(self, mean, std, device):

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_inplace(self, device):
input = make_image_tensor(dtype=torch.float32, device=device)
input_version = input._version
inpt = make_image_tensor(dtype=torch.float32, device=device)
input_version = inpt._version

output_out_of_place = F.normalize_image(input, mean=self.MEAN, std=self.STD)
assert output_out_of_place.data_ptr() != input.data_ptr()
assert output_out_of_place is not input
output_out_of_place = F.normalize_image(inpt, mean=self.MEAN, std=self.STD)
assert output_out_of_place.data_ptr() != inpt.data_ptr()
assert output_out_of_place is not inpt

output_inplace = F.normalize_image(input, mean=self.MEAN, std=self.STD, inplace=True)
assert output_inplace.data_ptr() == input.data_ptr()
output_inplace = F.normalize_image(inpt, mean=self.MEAN, std=self.STD, inplace=True)
assert output_inplace.data_ptr() == inpt.data_ptr()
assert output_inplace._version > input_version
assert output_inplace is input
assert output_inplace is inpt

assert_equal(output_inplace, output_out_of_place)

def test_kernel_video(self):
check_kernel(F.normalize_video, make_video(dtype=torch.float32), mean=self.MEAN, std=self.STD)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, make_input):
check_functional(F.normalize, make_input(dtype=torch.float32), mean=self.MEAN, std=self.STD)

Expand All @@ -5579,9 +5589,16 @@ def test_functional(self, make_input):
(F.normalize_image, torch.Tensor),
(F.normalize_image, tv_tensors.Image),
(F.normalize_video, tv_tensors.Video),
pytest.param(
F._misc._normalize_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
    # The CV-CUDA case is parametrized with ``None`` as its input type; the real
    # ``cvcuda.Tensor`` class is looked up here, inside the test body, so it is
    # only imported when this (skip-guarded) case actually runs.
    is_cvcuda_kernel = kernel is F._misc._normalize_image_cvcuda
    resolved_type = _import_cvcuda().Tensor if is_cvcuda_kernel else input_type
    check_functional_kernel_signature_match(F.normalize, kernel=kernel, input_type=resolved_type)

def test_functional_error(self):
Expand All @@ -5595,9 +5612,9 @@ def test_functional_error(self):
with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"):
F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std)

def _sample_input_adapter(self, transform, input, device):
def _sample_input_adapter(self, transform, inpt, device):
adapted_input = {}
for key, value in input.items():
for key, value in inpt.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
Expand All @@ -5607,7 +5624,17 @@ def _sample_input_adapter(self, transform, input, device):
adapted_input[key] = value
return adapted_input

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, make_input):
check_transform(
transforms.Normalize(mean=self.MEAN, std=self.STD),
Expand All @@ -5622,14 +5649,33 @@ def _reference_normalize_image(self, image, *, mean, std):

@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
def test_correctness_image(self, mean, std, dtype, fn):
image = make_image(dtype=dtype)
def test_correctness_image(self, mean, std, dtype, make_input, fn):
if make_input == make_image_cvcuda and dtype != torch.float32:
pytest.skip("CVCUDA only supports float32 for normalize")

image = make_input(dtype=dtype)

actual = fn(image, mean=mean, std=std)

if make_input == make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = self._reference_normalize_image(image, mean=mean, std=std)

assert_equal(actual, expected)
if make_input == make_image_cvcuda:
assert_close(actual, expected, rtol=0, atol=1e-6)
else:
assert_equal(actual, expected)


class TestClampBoundingBoxes:
Expand Down
7 changes: 7 additions & 0 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor

from ._utils import (
_parse_labels_getter,
Expand All @@ -21,6 +22,9 @@
)


CVCUDA_AVAILABLE = _is_cvcuda_available()


# TODO: do we want/need to expose this?
class Identity(Transform):
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down Expand Up @@ -160,6 +164,9 @@ class Normalize(Transform):

_v1_transform_cls = _transforms.Normalize

if CVCUDA_AVAILABLE:
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
self.mean = list(mean)
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
_is_cvcuda_tensor,
),
)
}
Expand Down
22 changes: 21 additions & 1 deletion torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
return get_dimensions_image(video)


def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
    """Return ``[channels, height, width]`` for a CV-CUDA image tensor.

    The tensor is read as NHWC (batch, height, width, channels); the batch
    axis is dropped and the remaining axes are reordered to the CHW
    convention that ``get_dimensions`` reports.
    """
    shape = image.shape
    channels, height, width = shape[3], shape[1], shape[2]
    return [channels, height, width]


# Hook the CV-CUDA kernel up to ``get_dimensions`` only when CV-CUDA is
# available, since the registration needs the ``cvcuda.Tensor`` class itself.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda)


def get_num_channels(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting():
return get_num_channels_image(inpt)
Expand Down Expand Up @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:
get_image_num_channels = get_num_channels


def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int:
    """Return the number of channels of a CV-CUDA image tensor.

    The tensor is read as NHWC (batch, height, width, channels), so the
    channel count is the size of the fourth axis.
    """
    channel_axis = 3  # N=0, H=1, W=2, C=3
    return image.shape[channel_axis]


# Hook the CV-CUDA kernel up to ``get_num_channels`` only when CV-CUDA is
# available, since the registration needs the ``cvcuda.Tensor`` class itself.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda)


def get_size(inpt: torch.Tensor) -> list[int]:
if torch.jit.is_scripting():
return get_size_image(inpt)
Expand Down Expand Up @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:


if CVCUDA_AVAILABLE:
_get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
_register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda)


@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
Expand Down
49 changes: 47 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import Optional, TYPE_CHECKING

import PIL.Image
import torch
Expand All @@ -13,7 +13,14 @@

from ._meta import _convert_bounding_box_format

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor

# Evaluated once at import time; used below to guard the CV-CUDA kernel registration.
CVCUDA_AVAILABLE = _is_cvcuda_available()

# Static type checkers see a plain import so the "cvcuda.Tensor" string
# annotations in this module can be resolved.
if TYPE_CHECKING:
    import cvcuda  # type: ignore[import-not-found]
# At runtime the module-level name ``cvcuda`` is bound only when CV-CUDA is available.
if CVCUDA_AVAILABLE:
    cvcuda = _import_cvcuda()  # noqa: F811


def normalize(
Expand Down Expand Up @@ -72,6 +79,44 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in
return normalize_image(video, mean, std, inplace=inplace)


def _normalize_image_cvcuda(
    image: "cvcuda.Tensor",
    mean: list[float],
    std: list[float],
    inplace: bool = False,
) -> "cvcuda.Tensor":
    """Normalize a CV-CUDA image tensor with per-channel mean and std.

    Args:
        image: float32 CV-CUDA tensor; the channel count is read from
            ``image.shape[3]``, i.e. the tensor is treated as NHWC.
        mean: One value per channel, or a single number that is repeated for
            every channel.
        std: One value per channel, or a single number that is repeated for
            every channel. No entry may be zero.
        inplace: Must be ``False``; in-place operation is not supported here.

    Returns:
        The tensor returned by ``cvcuda.normalize`` called with ``mean`` as
        ``base`` and ``std`` as ``scale`` under the ``SCALE_IS_STDDEV`` flag.

    Raises:
        ValueError: If ``inplace`` is true, if ``image`` is not float32, if
            ``mean`` or ``std`` does not have one entry per channel, or if any
            entry of ``std`` is zero.
    """
    cvcuda = _import_cvcuda()
    if inplace:
        raise ValueError("Inplace normalization is not supported for CVCUDA.")

    # CV-CUDA supports signed int and float tensors.
    # torchvision only supports uint and float; right now CV-CUDA doesn't expose float16, so only check 32.
    # In the future add float16 once it is exposed in CV-CUDA.
    if not (image.dtype == cvcuda.Type.F32):
        raise ValueError(f"Input tensor should be a float tensor. Got {image.dtype}.")

    channels = image.shape[3]
    if isinstance(mean, float | int):
        mean = [mean] * channels
    elif len(mean) != channels:
        raise ValueError(f"Mean should have {channels} elements. Got {len(mean)}.")
    if isinstance(std, float | int):
        std = [std] * channels
    elif len(std) != channels:
        raise ValueError(f"Std should have {channels} elements. Got {len(std)}.")

    # Keep parity with the torch.Tensor kernel, which rejects a zero std up front
    # instead of producing a division-by-zero result.
    if any(value == 0 for value in std):
        raise ValueError("std evaluated to zero, leading to division by zero.")

    # mean/std are shaped (1, 1, 1, C) so they line up with the channel axis of
    # the NHWC image.
    # NOTE(review): ``.cuda()`` places these on the current CUDA device; confirm
    # that this always matches the device holding ``image`` on multi-GPU setups.
    mt = torch.as_tensor(mean, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
    st = torch.as_tensor(std, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
    mean_cv = cvcuda.as_tensor(mt, cvcuda.TensorLayout.NHWC)
    std_cv = cvcuda.as_tensor(st, cvcuda.TensorLayout.NHWC)

    # SCALE_IS_STDDEV tells CV-CUDA that ``scale`` holds standard deviations
    # rather than plain multipliers.
    return cvcuda.normalize(image, base=mean_cv, scale=std_cv, flags=cvcuda.NormalizeFlags.SCALE_IS_STDDEV)


# Hook the CV-CUDA kernel up to ``normalize`` only when CV-CUDA is available,
# since the registration needs the ``cvcuda.Tensor`` class itself.
if CVCUDA_AVAILABLE:
    _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_image_cvcuda)


def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.GaussianBlur` for details."""
if torch.jit.is_scripting():
Expand Down