-
Notifications
You must be signed in to change notification settings - Fork 7.2k
Resize CV-CUDA Backend #9302
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Resize CV-CUDA Backend #9302
Changes from all commits
44db71c
e3dd700
c035df1
98d7dfb
ddc116d
e51dc7e
e14e210
4939355
fbea584
0a7886c
aa38855
4f2752a
5fa8d4e
a895bae
88d5c39
ca6e5f4
9915aa9
88f00b9
aaaef0e
1f7b897
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
| import torchvision.transforms.v2 as transforms | ||
|
|
||
| from common_utils import ( | ||
| assert_close, | ||
| assert_equal, | ||
| cache, | ||
| cpu_and_cuda, | ||
|
|
@@ -42,7 +43,6 @@ | |
| ) | ||
|
|
||
| from torch import nn | ||
| from torch.testing import assert_close | ||
| from torch.utils._pytree import tree_flatten, tree_map | ||
| from torch.utils.data import DataLoader, default_collate | ||
| from torchvision import tv_tensors | ||
|
|
@@ -804,6 +804,10 @@ def test_kernel_video(self): | |
| make_segmentation_mask, | ||
| make_video, | ||
| make_keypoints, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_functional(self, size, make_input): | ||
|
|
@@ -827,9 +831,16 @@ def test_functional(self, size, make_input): | |
| (F.resize_mask, tv_tensors.Mask), | ||
| (F.resize_video, tv_tensors.Video), | ||
| (F.resize_keypoints, tv_tensors.KeyPoints), | ||
| pytest.param( | ||
| F._geometry._resize_image_cvcuda, | ||
| None, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_functional_signature(self, kernel, input_type): | ||
| if kernel is F._geometry._resize_image_cvcuda: | ||
| input_type = _import_cvcuda().Tensor | ||
| check_functional_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type) | ||
|
|
||
| @pytest.mark.parametrize("size", OUTPUT_SIZES) | ||
|
|
@@ -845,6 +856,10 @@ def test_functional_signature(self, kernel, input_type): | |
| make_detection_masks, | ||
| make_video, | ||
| make_keypoints, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_transform(self, size, device, make_input): | ||
|
|
@@ -862,23 +877,77 @@ def _check_output_size(self, input, output, *, size, max_size): | |
| input_size=F.get_size(input), size=size, max_size=max_size | ||
| ) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "make_input", | ||
| [ | ||
| make_image, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("size", OUTPUT_SIZES) | ||
| # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2. | ||
| # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT` | ||
| @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST}) | ||
| @pytest.mark.parametrize("use_max_size", [True, False]) | ||
| @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)]) | ||
| def test_image_correctness(self, size, interpolation, use_max_size, fn): | ||
| def test_image_correctness(self, make_input, size, interpolation, use_max_size, fn): | ||
| if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): | ||
| return | ||
|
|
||
| image = make_image(self.INPUT_SIZE, dtype=torch.uint8) | ||
| image = make_input(self.INPUT_SIZE, dtype=torch.uint8) | ||
|
|
||
| actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True) | ||
|
|
||
| if make_input is make_image_cvcuda: | ||
| image = F.cvcuda_to_tensor(image)[0].cpu() | ||
|
|
||
| expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg)) | ||
|
|
||
| self._check_output_size(image, actual, size=size, **max_size_kwarg) | ||
| torch.testing.assert_close(actual, expected, atol=1, rtol=0) | ||
|
|
||
| atol = 1 | ||
| # when using antialias, CV-CUDA is different for BICUBIC and BILINEAR, since antialias requires hq_resize | ||
| # hq_resize using interpolation will have differences on the edge boundaries | ||
| # no noticable visual difference | ||
| if make_input is make_image_cvcuda and ( | ||
| interpolation is transforms.InterpolationMode.BILINEAR | ||
| or interpolation is transforms.InterpolationMode.BICUBIC | ||
| ): | ||
| atol = 9 | ||
| assert_close(actual, expected, atol=atol, rtol=0) | ||
|
|
||
| @needs_cvcuda | ||
| @pytest.mark.parametrize("size", OUTPUT_SIZES) | ||
| @pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST}) | ||
| @pytest.mark.parametrize("use_max_size", [True, False]) | ||
| @pytest.mark.parametrize("antialias", [True, False]) | ||
| @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)]) | ||
| def test_image_correctness_cvcuda(self, size, interpolation, use_max_size, antialias, fn): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am curious why we have two test_image_correctness functions for cvcuda. If this function test cvcuda specifically, then we don't need to add "make_image_cvcuda" into the "make_ input" for the function in https://github.com/pytorch/vision/pull/9302/changes#diff-9c2dde92db86c123fee225e39b7c1ef96e08a3e79a9dcc9a2d68b21ed51a81d0R896 |
||
| if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)): | ||
| return | ||
|
|
||
| image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8) | ||
| actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias) | ||
| expected = fn( | ||
| F.cvcuda_to_tensor(image), size=size, interpolation=interpolation, **max_size_kwarg, antialias=antialias | ||
| ) | ||
|
|
||
| # assert_close will squeeze the batch dimension off the CV-CUDA tensor so we convert ahead of time | ||
| actual = F.cvcuda_to_tensor(actual) | ||
|
|
||
| atol = 1 | ||
| if antialias: | ||
| # cvcuda.hq_resize is accurate within 9 for the tests | ||
| atol = 9 | ||
| elif interpolation == transforms.InterpolationMode.BICUBIC: | ||
| # the CV-CUDA bicubic interpolation differs significantly | ||
| # importantly, this is only the edge boundaries | ||
| # visually, there is no noticable difference | ||
| atol = 91 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the reason for choosing 91? |
||
| assert_close(actual, expected, atol=atol, rtol=0) | ||
|
|
||
| def _reference_resize_bounding_boxes(self, bounding_boxes, format, *, size, max_size=None): | ||
| old_height, old_width = bounding_boxes.canvas_size | ||
|
|
@@ -964,11 +1033,26 @@ def test_keypoints_correctness(self, size, use_max_size, fn): | |
| @pytest.mark.parametrize("interpolation", set(transforms.InterpolationMode) - set(INTERPOLATION_MODES)) | ||
| @pytest.mark.parametrize( | ||
| "make_input", | ||
| [make_image_tensor, make_image_pil, make_image, make_video], | ||
| [ | ||
| make_image_tensor, | ||
| make_image_pil, | ||
| make_image, | ||
| make_video, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_pil_interpolation_compat_smoke(self, interpolation, make_input): | ||
| input = make_input(self.INPUT_SIZE) | ||
|
|
||
| if make_input is make_image_cvcuda and interpolation in { | ||
| transforms.InterpolationMode.BOX, | ||
| transforms.InterpolationMode.LANCZOS, | ||
| }: | ||
| pytest.skip("CV-CUDA may support box and lanczos for certain configurations of resize") | ||
|
|
||
| with ( | ||
| contextlib.nullcontext() | ||
| if isinstance(input, PIL.Image.Image) | ||
|
|
@@ -997,6 +1081,10 @@ def test_functional_pil_antialias_warning(self): | |
| make_detection_masks, | ||
| make_video, | ||
| make_keypoints, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_max_size_error(self, size, make_input): | ||
|
|
@@ -1040,6 +1128,10 @@ def test_max_size_error(self, size, make_input): | |
| make_detection_masks, | ||
| make_video, | ||
| make_keypoints, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_resize_size_none(self, input_size, max_size, expected_size, make_input): | ||
|
|
@@ -1050,7 +1142,16 @@ def test_resize_size_none(self, input_size, max_size, expected_size, make_input) | |
| @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES) | ||
| @pytest.mark.parametrize( | ||
| "make_input", | ||
| [make_image_tensor, make_image_pil, make_image, make_video], | ||
| [ | ||
| make_image_tensor, | ||
| make_image_pil, | ||
| make_image, | ||
| make_video, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_interpolation_int(self, interpolation, make_input): | ||
| input = make_input(self.INPUT_SIZE) | ||
|
|
@@ -1114,6 +1215,10 @@ def test_noop(self, size, make_input): | |
| make_detection_masks, | ||
| make_video, | ||
| make_keypoints, | ||
| pytest.param( | ||
| make_image_cvcuda, | ||
| marks=pytest.mark.needs_cvcuda, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_no_regression_5405(self, make_input): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
|
|
||
| from ._utils import ( | ||
| _FillTypeJIT, | ||
| _get_cvcuda_interp, | ||
| _get_kernel, | ||
| _import_cvcuda, | ||
| _is_cvcuda_available, | ||
|
|
@@ -401,6 +402,82 @@ def __resize_image_pil_dispatch( | |
| return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) | ||
|
|
||
|
|
||
| _dtype_to_format_cvcuda: dict["cvcuda.Type", "cvcuda.Format"] = {} | ||
|
|
||
|
|
||
| def _resize_image_cvcuda( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to take a closer look at the implementation of this function @NicolasHug I looked through the implementation and the logic looks good. However, we might need to: |
||
| image: "cvcuda.Tensor", | ||
| size: Optional[list[int]], | ||
| interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, | ||
| max_size: Optional[int] = None, | ||
| antialias: Optional[bool] = True, | ||
| ) -> "cvcuda.Tensor": | ||
| cvcuda = _import_cvcuda() | ||
|
|
||
| if not _dtype_to_format_cvcuda: | ||
| _dtype_to_format_cvcuda[cvcuda.Type.U8] = cvcuda.Format.U8 | ||
| _dtype_to_format_cvcuda[cvcuda.Type.U16] = cvcuda.Format.U16 | ||
| _dtype_to_format_cvcuda[cvcuda.Type.U32] = cvcuda.Format.U32 | ||
| _dtype_to_format_cvcuda[cvcuda.Type.S8] = cvcuda.Format.S8 | ||
| _dtype_to_format_cvcuda[cvcuda.Type.S16] = cvcuda.Format.S16 | ||
| _dtype_to_format_cvcuda[cvcuda.Type.S32] = cvcuda.Format.S32 | ||
| _dtype_to_format_cvcuda[cvcuda.Type.F32] = cvcuda.Format.F32 | ||
| _dtype_to_format_cvcuda[cvcuda.Type.F64] = cvcuda.Format.F64 | ||
|
|
||
| interp = _get_cvcuda_interp(interpolation) | ||
| # hamming error for parity to resize_image | ||
| if interp == cvcuda.Interp.HAMMING: | ||
| raise NotImplementedError("Unsupported interpolation for CV-CUDA resize, got hamming.") | ||
|
|
||
| # match the antialias behavior of resize_image | ||
| if not (interp == cvcuda.Interp.LINEAR or interp == cvcuda.Interp.CUBIC): | ||
| antialias = False | ||
|
|
||
| old_height, old_width = image.shape[1], image.shape[2] | ||
| new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) | ||
|
|
||
| # No resize needed if dimensions match | ||
| if new_height == old_height and new_width == old_width: | ||
| return image | ||
|
|
||
| # antialias is only supported for cvcuda.hq_resize, if set to true (which is also default) | ||
| # we will fast-track to use hq_resize (also matchs the size parameter) | ||
| if antialias: | ||
| return cvcuda.hq_resize( | ||
| image, | ||
| out_size=(new_height, new_width), | ||
| interpolation=interp, | ||
| antialias=antialias, | ||
| ) | ||
|
|
||
| # if not using antialias, we will use cvcuda.resize/pillowresize instead | ||
| # resize requires that the shape has the same dimensions as the input | ||
| # CV-CUDA tensors are already in NHWC format so we can do a simple tuple creation | ||
| shape = image.shape | ||
| new_shape = (shape[0], new_height, new_width, shape[3]) | ||
|
|
||
| # bicubic mode is not accurate when using cvcuda.resize | ||
| # cvcuda.pillowresize resolves some of the errors | ||
| if interp == cvcuda.Interp.CUBIC: | ||
| return cvcuda.pillowresize( | ||
| image, | ||
| shape=new_shape, | ||
| format=_dtype_to_format_cvcuda[image.dtype], | ||
| interp=interp, | ||
| ) | ||
|
|
||
| # otherwise we will use cvcuda.resize | ||
| return cvcuda.resize( | ||
| image, | ||
| shape=new_shape, | ||
| interp=interp, | ||
| ) | ||
|
|
||
|
|
||
| if CVCUDA_AVAILABLE: | ||
| _register_kernel_internal(resize, _import_cvcuda().Tensor)(_resize_image_cvcuda) | ||
|
|
||
|
|
||
| def resize_mask(mask: torch.Tensor, size: Optional[list[int]], max_size: Optional[int] = None) -> torch.Tensor: | ||
| if mask.ndim < 3: | ||
| mask = mask.unsqueeze(0) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure how to verify the
atol = 9. What is the reason to pick 9?