diff --git a/docs/source/tv_tensors.rst b/docs/source/tv_tensors.rst index cb8a3c45fa9..d292012fdf8 100644 --- a/docs/source/tv_tensors.rst +++ b/docs/source/tv_tensors.rst @@ -21,6 +21,7 @@ info. Image Video + KeyPoints BoundingBoxFormat BoundingBoxes Mask diff --git a/gallery/transforms/plot_tv_tensors.py b/gallery/transforms/plot_tv_tensors.py index 5bce37aa374..2c6ebbf9031 100644 --- a/gallery/transforms/plot_tv_tensors.py +++ b/gallery/transforms/plot_tv_tensors.py @@ -46,11 +46,12 @@ # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function # for the input data. # -# :mod:`torchvision.tv_tensors` supports four types of TVTensors: +# :mod:`torchvision.tv_tensors` supports five types of TVTensors: # # * :class:`~torchvision.tv_tensors.Image` # * :class:`~torchvision.tv_tensors.Video` # * :class:`~torchvision.tv_tensors.BoundingBoxes` +# * :class:`~torchvision.tv_tensors.KeyPoints` # * :class:`~torchvision.tv_tensors.Mask` # # What can I do with a TVTensor? @@ -96,6 +97,7 @@ # :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the # corresponding image (``canvas_size``) alongside the actual values. These # metadata are required to properly transform the bounding boxes. +# In a similar fashion, :class:`~torchvision.tv_tensors.KeyPoints` also require the ``canvas_size`` metadata to be added. bboxes = tv_tensors.BoundingBoxes( [[17, 16, 344, 495], [0, 10, 0, 10]], @@ -104,6 +106,13 @@ ) print(bboxes) + +keypoints = tv_tensors.KeyPoints( + [[17, 16], [344, 495], [0, 10], [0, 10]], + canvas_size=image.shape[-2:] +) +print(keypoints) + # %% # Using ``tv_tensors.wrap()`` # ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/test/common_utils.py b/test/common_utils.py index b3a26dfd441..bf0fe92ae3e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -8,6 +8,7 @@ import shutil import sys import tempfile +from typing import Sequence, Tuple import warnings from subprocess import CalledProcessError, check_output, STDOUT @@ -400,6 +401,20 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) +def make_keypoints( + canvas_size: Tuple[int, int] = DEFAULT_SIZE, *, num_points: int | Sequence[int] = 4, dtype=None, device='cpu' +) -> tv_tensors.KeyPoints: + """Make the KeyPoints for testing purposes""" + if isinstance(num_points, int): + num_points = [num_points] + single_coord_shape: Tuple[int, ...] = tuple(num_points) + (1,) + y = torch.randint(0, canvas_size[0] - 1, single_coord_shape, dtype=dtype, device=device) + x = torch.randint(0, canvas_size[1] - 1, single_coord_shape, dtype=dtype, device=device) + points = torch.cat((x, y), dim=-1) + keypoints = tv_tensors.KeyPoints(points, canvas_size=canvas_size) + return keypoints + + def make_bounding_boxes( canvas_size=DEFAULT_SIZE, *, diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 94d90b9e2f6..701d668b6f2 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -31,6 +31,7 @@ make_image, make_image_pil, make_image_tensor, + make_keypoints, make_segmentation_mask, make_video, make_video_tensor, @@ -232,6 +233,7 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type): # explicitly passed to the kernel. 
explicit_metadata = { tv_tensors.BoundingBoxes: {"format", "canvas_size"}, + tv_tensors.KeyPoints: {"canvas_size"} } kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())] @@ -336,6 +338,18 @@ def _make_transform_sample(transform, *, image_or_video, adapter): canvas_size=size, device=device, ), + keypoints=make_keypoints(canvas_size=size), keypoints_degenerate=tv_tensors.KeyPoints( + [ + [0, 1], # left edge + [1, 0], # top edge + [0, 0], # top left corner + [size[1], 1], # right edge + [size[1], 0], # top right corner + [1, size[0]], # bottom edge + [0, size[0]], # bottom left corner + [size[1], size[0]] # bottom right corner + ], canvas_size=size, device=device + ), detection_mask=make_detection_masks(size, device=device), segmentation_mask=make_segmentation_mask(size, device=device), int=0, @@ -689,6 +703,7 @@ def test_functional(self, size, make_input): (F.resize_image, torch.Tensor), (F._geometry._resize_image_pil, PIL.Image.Image), (F.resize_image, tv_tensors.Image), + (F.resize_keypoints, tv_tensors.KeyPoints), (F.resize_bounding_boxes, tv_tensors.BoundingBoxes), (F.resize_mask, tv_tensors.Mask), (F.resize_video, tv_tensors.Video), @@ -1044,6 +1059,7 @@ def test_functional(self, make_input): (F.horizontal_flip_image, torch.Tensor), (F._geometry._horizontal_flip_image_pil, PIL.Image.Image), (F.horizontal_flip_image, tv_tensors.Image), + (F.horizontal_flip_keypoints, tv_tensors.KeyPoints), (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.horizontal_flip_mask, tv_tensors.Mask), (F.horizontal_flip_video, tv_tensors.Video), @@ -1214,6 +1230,7 @@ def test_functional(self, make_input): (F.affine_image, torch.Tensor), (F._geometry._affine_image_pil, PIL.Image.Image), (F.affine_image, tv_tensors.Image), + (F.affine_keypoints, tv_tensors.KeyPoints), (F.affine_bounding_boxes, tv_tensors.BoundingBoxes), (F.affine_mask, tv_tensors.Mask), (F.affine_video, tv_tensors.Video), @@ -1496,6 +1513,7 @@ def test_functional(self, make_input): (F.vertical_flip_image, torch.Tensor), (F._geometry._vertical_flip_image_pil, PIL.Image.Image), (F.vertical_flip_image, tv_tensors.Image), + (F.vertical_flip_keypoints, tv_tensors.KeyPoints), (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.vertical_flip_mask, tv_tensors.Mask), (F.vertical_flip_video, tv_tensors.Video), @@ -1638,6 +1656,7 @@ def test_functional(self, make_input): (F.rotate_image, torch.Tensor), (F._geometry._rotate_image_pil, PIL.Image.Image), (F.rotate_image, tv_tensors.Image), + (F.rotate_keypoints, tv_tensors.KeyPoints), (F.rotate_bounding_boxes, tv_tensors.BoundingBoxes), (F.rotate_mask, tv_tensors.Mask), (F.rotate_video, tv_tensors.Video), @@ -2343,7 +2362,9 @@ def test_error(self, T): F.to_pil_image(imgs[0]), tv_tensors.Mask(torch.rand(12, 12)), tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12), + tv_tensors.KeyPoints(torch.rand(2, 2), canvas_size=(12, 12)) ): + print(type(input_with_bad_type), cutmix_mixup) with pytest.raises(ValueError, match="does not support PIL images, "): cutmix_mixup(input_with_bad_type) @@ -2751,8 +2772,9 @@ def test_functional_signature(self, kernel, input_type): check_functional_kernel_signature_match(F.elastic, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( - "make_input", - [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + "make_input", [ + make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, 
make_video, make_keypoints + ], ) def test_displacement_error(self, make_input): input = make_input() @@ -2764,8 +2786,10 @@ def test_displacement_error(self, make_input): F.elastic(input, displacement=torch.rand(F.get_size(input))) @pytest.mark.parametrize( - "make_input", - [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + "make_input", [ + make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video, + make_keypoints + ], ) # ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image @pytest.mark.parametrize("size", [(163, 163), (72, 333), (313, 95)]) @@ -3471,7 +3495,7 @@ def _sample_input_adapter(self, transform, input, device): adapted_input = {} image_or_video_found = False for key, value in input.items(): - if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): # AA transforms don't support bounding boxes or masks continue elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)): @@ -6271,3 +6295,23 @@ def test_different_sizes(self, make_input1, make_input2, query): def test_no_valid_input(self, query): with pytest.raises(TypeError, match="No image"): query(["blah"]) + + @pytest.mark.parametrize( + 'boxes', [ + tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4)) + ] + ) + def test_convert_bounding_boxes_to_points(self, boxes: tv_tensors.BoundingBoxes): + # TODO: this test can't handle rotated boxes yet + kp = F.convert_bounding_boxes_to_points(boxes) + assert kp.shape == boxes.shape + (2, ) + assert kp.dtype == boxes.dtype + # kp is a list of A, B, C, D polygons. 
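+        # where A = (x1, y1) top-left, B = (x2, y1) top-right, C = (x2, y2) bottom-right, D = (x1, y2) bottom-left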
+ # If we use A | C, we should get back the XYXY format of bounding box + reconverted = torch.cat([kp[..., 0, :], kp[..., 2, :]], dim=-1) + reconverted_bbox = F.convert_bounding_box_format( + tv_tensors.BoundingBoxes( + reconverted, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=kp.canvas_size + ), new_format=boxes.format + ) + assert (reconverted_bbox == boxes).all(), f"Invalid reconversion : {reconverted_bbox}" diff --git a/test/test_transforms_v2_utils.py b/test/test_transforms_v2_utils.py index 53222c6a2c8..813a3cd93e6 100644 --- a/test/test_transforms_v2_utils.py +++ b/test/test_transforms_v2_utils.py @@ -4,7 +4,7 @@ import torch import torchvision.transforms.v2._utils -from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image +from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_masks, make_image, make_keypoints from torchvision import tv_tensors from torchvision.transforms.v2._utils import has_all, has_any @@ -14,29 +14,32 @@ IMAGE = make_image(DEFAULT_SIZE, color_space="RGB") BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY) MASK = make_detection_masks(DEFAULT_SIZE) +KEYPOINTS = make_keypoints(DEFAULT_SIZE) @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), - ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False), - ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False), - ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.KeyPoints,), True), + ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), False), + ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask, tv_tensors.KeyPoints), False), + ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False), + ((KEYPOINTS,), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False), ( - (IMAGE, BOUNDING_BOX, MASK), - (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), + (IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), + (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True, ), - ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), - ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), + ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), False), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda obj: isinstance(obj, 
tv_tensors.Image),), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: False,), False), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (lambda _: True,), True), ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True), ( (torch.Tensor(IMAGE),), @@ -57,15 +60,18 @@ def test_has_any(sample, types, expected): @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True), - ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask,), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.BoundingBoxes), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Image, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.Mask, tv_tensors.KeyPoints), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints), True), + ((IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True), ( - (IMAGE, BOUNDING_BOX, MASK), - (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), + (IMAGE, BOUNDING_BOX, MASK, KEYPOINTS), + (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), True, ), ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False), diff --git a/test/test_tv_tensors.py b/test/test_tv_tensors.py index a8e59ab7531..a29c942db67 100644 --- a/test/test_tv_tensors.py +++ b/test/test_tv_tensors.py @@ -2,7 +2,7 @@ import pytest import torch -from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video +from common_utils import assert_equal, make_bounding_boxes, make_image, make_keypoints, make_segmentation_mask, make_video from PIL import Image from torchvision import tv_tensors @@ -49,6 +49,20 @@ def test_bbox_dim_error(): tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32)) +@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 2)), [[0, 0,], [2, 2,]], [1, 2,]]) +def test_keypoints_instance(data): + kpoint = tv_tensors.KeyPoints(data, canvas_size=(32, 32)) + assert isinstance(kpoint, tv_tensors.KeyPoints) + assert type(kpoint) is tv_tensors.KeyPoints + assert kpoint.shape[-1] == 2 + + +def test_keypoints_shape_error(): + data_3d = [(0, 1, 2)] + with pytest.raises(ValueError, match="shape"): + tv_tensors.KeyPoints(torch.tensor(data_3d), canvas_size=(11, 7)) + + @pytest.mark.parametrize( ("data", "input_requires_grad", "expected_requires_grad"), [ @@ -68,7 +82,9 @@ def test_new_requires_grad(data, input_requires_grad, expected_requires_grad): assert tv_tensor.requires_grad is expected_requires_grad -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, 
make_keypoints +]) def test_isinstance(make_input): assert isinstance(make_input(), torch.Tensor) @@ -80,7 +96,9 @@ def test_wrapping_no_copy(): assert image.data_ptr() == tensor.data_ptr() -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) def test_to_wrapping(make_input): dp = make_input() @@ -90,7 +108,9 @@ def test_to_wrapping(make_input): assert dp_to.dtype is torch.float64 -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_to_tv_tensor_reference(make_input, return_type): tensor = torch.rand((3, 16, 16), dtype=torch.float64) @@ -104,7 +124,9 @@ def test_to_tv_tensor_reference(make_input, return_type): assert type(tensor) is torch.Tensor -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_clone_wrapping(make_input, return_type): dp = make_input() @@ -116,7 +138,9 @@ def test_clone_wrapping(make_input, return_type): assert dp_clone.data_ptr() != dp.data_ptr() -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_requires_grad__wrapping(make_input, return_type): dp = make_input(dtype=torch.float) @@ -131,7 +155,9 @@ def test_requires_grad__wrapping(make_input, return_type): assert dp_requires_grad.requires_grad -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_detach_wrapping(make_input, return_type): dp = make_input(dtype=torch.float).requires_grad_(True) @@ -148,18 +174,25 @@ def test_force_subclass_with_metadata(return_type): # Largely the same as above, we additionally check that the metadata is preserved format, canvas_size = "XYXY", (32, 32) bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size) + kpoints = tv_tensors.KeyPoints([[0, 0], [2, 2]], canvas_size=canvas_size) tv_tensors.set_return_type(return_type) bbox = bbox.clone() + kpoints = kpoints.clone() if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) bbox = bbox.to(torch.float64) + kpoints = kpoints.to(torch.float64) if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) bbox = bbox.detach() + kpoints = kpoints.detach() if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) if torch.cuda.is_available(): @@ 
-168,14 +201,20 @@ def test_force_subclass_with_metadata(return_type): assert bbox.format, bbox.canvas_size == (format, canvas_size) assert not bbox.requires_grad + assert not kpoints.requires_grad bbox.requires_grad_(True) + kpoints.requires_grad_(True) if return_type == "TVTensor": + assert kpoints.canvas_size == canvas_size assert bbox.format, bbox.canvas_size == (format, canvas_size) assert bbox.requires_grad + assert kpoints.requires_grad tv_tensors.set_return_type("tensor") -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_other_op_no_wrapping(make_input, return_type): dp = make_input() @@ -187,7 +226,9 @@ def test_other_op_no_wrapping(make_input, return_type): assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor) -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) @pytest.mark.parametrize( "op", [ @@ -204,7 +245,9 @@ def test_no_tensor_output_op_no_wrapping(make_input, op): assert type(output) is not type(dp) -@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) +@pytest.mark.parametrize("make_input", [ + make_image, make_bounding_boxes, make_segmentation_mask, make_video, make_keypoints +]) @pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"]) def test_inplace_op_no_wrapping(make_input, return_type): dp = make_input() diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 02a487cabd3..980e27647f7 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -90,7 +90,7 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." @@ -157,7 +157,7 @@ def forward(self, *inputs): flat_inputs, spec = tree_flatten(inputs) needs_transform_list = self._needs_transform_list(flat_inputs) - if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask): + if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints): raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.") labels = self._labels_getter(inputs) diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index c743eb40775..52707af1f2e 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -46,7 +46,7 @@ def _get_random_item(self, dct: dict[str, tuple[Callable, bool]]) -> tuple[str, def _flatten_and_extract_image_or_video( self, inputs: Any, - unsupported_types: tuple[type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask), + unsupported_types: tuple[type, ...] 
= (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints), ) -> tuple[tuple[list[Any], TreeSpec, int], ImageOrVideo]: flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0]) needs_transform_list = self._needs_transform_list(flat_inputs) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 86c00e28a66..e1ed436ba36 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -357,7 +357,7 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." @@ -402,7 +402,7 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.vertical_flip = vertical_flip def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: - if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)): + if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, tv_tensors.Mask)): warnings.warn( f"{type(self).__name__}() is currently passing through inputs of type " f"tv_tensors.{type(inpt).__name__}. This will likely change in the future." diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index dfd521b13be..d6d61fa6d6c 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -341,9 +341,9 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any: class SanitizeBoundingBoxes(Transform): - """Remove degenerate/invalid bounding boxes and their corresponding labels and masks. + """Remove degenerate/invalid bounding boxes and their corresponding labels, masks and keypoints. - This transform removes bounding boxes and their associated labels/masks that: + This transform removes bounding boxes and their associated labels, masks and keypoints that: - are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1. - have any coordinate outside of their corresponding image. You may want to @@ -359,6 +359,14 @@ class SanitizeBoundingBoxes(Transform): may modify bounding boxes but once at the end should be enough in most cases. + .. note:: + This transform requires that any :class:`~torchvision.tv_tensor.KeyPoints` or + :class:`~torchvision.tv_tensor.Mask` provided has to match the bounding boxes in shape. + + If the bounding boxes are of shape ``[N, K]``, then the + KeyPoints have to be of shape ``[N, ..., 2]`` or ``[N, 2]`` + and the masks have to be of shape ``[N, ..., H, W]`` or ``[N, H, W]`` + Args: min_size (float, optional): The size below which bounding boxes are removed. Default is 1. min_area (float, optional): The area below which bounding boxes are removed. Default is 1. 
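For illustration, a minimal sketch of how the updated ``SanitizeBoundingBoxes`` is expected to handle keypoints once this change is in place. The sample below (keys, sizes and values) is made up for demonstration and is not part of the patch:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# Two boxes: the second is degenerate (zero width/height) and should be removed.
boxes = tv_tensors.BoundingBoxes(
    [[10, 10, 20, 20], [5, 5, 5, 5]], format="XYXY", canvas_size=(32, 32)
)
# One keypoint per box, shape (N, 1, 2) -- the first dimension must match the number of boxes.
keypoints = tv_tensors.KeyPoints([[[15, 15]], [[5, 5]]], canvas_size=(32, 32))
labels = torch.tensor([1, 2])

sample = {"boxes": boxes, "keypoints": keypoints, "labels": labels}
out = v2.SanitizeBoundingBoxes()(sample)
# Only the first box survives; keypoints and labels are filtered with the same validity mask.
print(out["boxes"].shape, out["keypoints"].shape, out["labels"])
```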
@@ -438,10 +446,15 @@ def forward(self, *inputs: Any) -> Any: return tree_unflatten(flat_outputs, spec) def transform(self, inpt: Any, params: dict[str, Any]) -> Any: + # For every object in the flattened input of the `forward` method, we apply transform + # The params contain the list of valid indices of the (N, K) bbox set + + # We suppose here that any KeyPoints or Masks TVTensors is of shape (N, ..., 2) and (N, ..., H, W) respectively + # TODO: check this. is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) - is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) + is_bbox_mask_or_kpoints = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask, tv_tensors.KeyPoints)) - if not (is_label or is_bounding_boxes_or_mask): + if not (is_label or is_bbox_mask_or_kpoints): return inpt output = inpt[params["valid"]] diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index c4371ce0953..34fb8ee4170 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from contextlib import suppress -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, Sequence, Iterable import PIL.Image import torch @@ -165,6 +165,20 @@ def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes: raise ValueError("No bounding boxes were found in the sample") +def get_all_keypoints(flat_inputs: list[Any]) -> Iterable[tv_tensors.KeyPoints]: + """Yields all KeyPoints in the input. + + Raises: + ValueError: No KeyPoints can be found + """ + generator = (inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.KeyPoints)) + try: + yield next(generator) + except StopIteration: + raise ValueError("No Keypoints were found in the sample.") + return generator + + def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: """Return Channel, Height, and Width.""" chws = { @@ -194,6 +208,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Video, tv_tensors.Mask, tv_tensors.BoundingBoxes, + tv_tensors.KeyPoints, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index d5705d55c4b..e32ef73f7c1 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -5,6 +5,7 @@ from ._meta import ( clamp_bounding_boxes, convert_bounding_box_format, + convert_bounding_boxes_to_points, get_dimensions_image, get_dimensions_video, get_dimensions, @@ -69,21 +70,25 @@ affine, affine_bounding_boxes, affine_image, + affine_keypoints, affine_mask, affine_video, center_crop, center_crop_bounding_boxes, center_crop_image, + center_crop_keypoints, center_crop_mask, center_crop_video, crop, crop_bounding_boxes, crop_image, + crop_keypoints, crop_mask, crop_video, elastic, elastic_bounding_boxes, elastic_image, + elastic_keypoints, elastic_mask, elastic_transform, elastic_video, @@ -94,13 +99,16 @@ horizontal_flip, horizontal_flip_bounding_boxes, horizontal_flip_image, + horizontal_flip_keypoints, horizontal_flip_mask, horizontal_flip_video, pad, pad_bounding_boxes, pad_image, + pad_keypoints, pad_mask, pad_video, + perspectice_keypoints, perspective, perspective_bounding_boxes, perspective_image, @@ -109,16 +117,19 @@ resize, resize_bounding_boxes, resize_image, + resize_keypoints, resize_mask, resize_video, resized_crop, resized_crop_bounding_boxes, 
resized_crop_image, + resized_crop_keypoints, resized_crop_mask, resized_crop_video, rotate, rotate_bounding_boxes, rotate_image, + rotate_keypoints, rotate_mask, rotate_video, ten_crop, @@ -127,6 +138,7 @@ vertical_flip, vertical_flip_bounding_boxes, vertical_flip_image, + vertical_flip_keypoints, vertical_flip_mask, vertical_flip_video, vflip, @@ -143,6 +155,7 @@ normalize_image, normalize_video, sanitize_bounding_boxes, + sanitize_keypoints, to_dtype, to_dtype_image, to_dtype_video, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 8303019e011..a74a211b9e7 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -23,7 +23,7 @@ from torchvision.utils import _log_api_usage_once -from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format +from ._meta import _get_size_image_pil, clamp_bounding_boxes, clamp_keypoints, convert_bounding_box_format from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal @@ -66,6 +66,17 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(mask) +def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]): + keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1]).neg_() + return keypoints + + +@_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints): + out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size) + return tv_tensors.wrap(out, like=keypoints) + + def horizontal_flip_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: tuple[int, int] ) -> torch.Tensor: @@ -123,6 +134,12 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask) +@_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def vertical_flip_keypoints(keypoints: tv_tensors.KeyPoints): + keypoints[..., 1] = keypoints[..., 1].sub_(keypoints.canvas_size[0]).neg_() + return keypoints + + def vertical_flip_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: tuple[int, int] ) -> torch.Tensor: @@ -334,6 +351,44 @@ def _resize_mask_dispatch( return tv_tensors.wrap(output, like=inpt) +def resize_keypoints( + keypoints: torch.Tensor, + size: Optional[list[int]], + canvas_size: tuple[int, int], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +): + old_height, old_width = canvas_size + new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size) + + w_ratio = new_width / old_width + h_ratio = new_height / old_height + ratios = torch.tensor([w_ratio, h_ratio], device=keypoints.device) + keypoints = keypoints.mul(ratios).to(keypoints.dtype) + + return keypoints, (new_height, new_width) + + +@_register_kernel_internal(resize, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _resize_keypoints_dispatch( + keypoints: tv_tensors.KeyPoints, + size: Optional[list[int]], + interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = True, +) -> tv_tensors.KeyPoints: + out, canvas_size = resize_keypoints( + 
keypoints.as_subclass(torch.Tensor), + size, + canvas_size=keypoints.canvas_size, + interpolation=interpolation, + max_size=max_size, + antialias=antialias, + ) + return tv_tensors.wrap(out, like=keypoints, canvas_size=canvas_size) + + def resize_bounding_boxes( bounding_boxes: torch.Tensor, canvas_size: tuple[int, int], @@ -759,6 +814,93 @@ def _affine_image_pil( return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill) +def _affine_keypoints_with_expand( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + angle: Union[int, float], + translate: list[float], + scale: float, + shear: list[float], + center: Optional[list[float]] = None, + expand: bool = False, +) -> tuple[torch.Tensor, tuple[int, int]]: + if keypoints.numel() == 0: + return keypoints, canvas_size + + original_dtype = keypoints.dtype + keypoints = keypoints.clone() if keypoints.is_floating_point() else keypoints.float() + dtype = keypoints.dtype + device = keypoints.device + + angle, translate, shear, center = _affine_parse_args( + angle, translate, scale, shear, InterpolationMode.NEAREST, center + ) + + if center is None: + height, width = canvas_size + center = [width * 0.5, height * 0.5] + + affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False) + transposed_affine_matrix = ( + torch.tensor( + affine_vector, + dtype=dtype, + device=device, + ) + .reshape(2, 3) + .T + ) + # 1) Unlike bounding box (whose implmentation we stole) we're already a bunch of points. + keypoints = torch.cat([keypoints, torch.ones(keypoints.shape[0], 1, device=device, dtype=dtype)], dim=-1) + # 2) Now let's transform the points using affine matrix + keypoints = torch.matmul(keypoints, transposed_affine_matrix).to(original_dtype) + + return keypoints, canvas_size + + +def affine_keypoints( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + angle: Union[int, float], + translate: list[float], + scale: float, + shear: list[float], + center: Optional[list[float]] = None, +): + return _affine_keypoints_with_expand( + keypoints=keypoints, + canvas_size=canvas_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + expand=False, + ) + + +@_register_kernel_internal(affine, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _affine_keypoints_dispatch( + inpt: tv_tensors.KeyPoints, + angle: Union[int, float], + translate: list[float], + scale: float, + shear: list[float], + center: Optional[list[float]] = None, + **kwargs, +) -> tv_tensors.KeyPoints: + output, canvas_size = affine_keypoints( + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def _affine_bounding_boxes_with_expand( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1056,6 +1198,34 @@ def _rotate_image_pil( ) +def rotate_keypoints( + keypoints: tv_tensors.KeyPoints, + angle: float, + interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[list[float]] = None, + fill: _FillTypeJIT = None, +) -> tuple[torch.Tensor, tuple[int, int]]: + return _affine_keypoints_with_expand( + keypoints=keypoints.as_subclass(torch.Tensor), + canvas_size=keypoints.canvas_size, + angle=-angle, + translate=[0.0, 0.0], + scale=1.0, + shear=[0.0, 0.0], + center=center, + expand=expand, + ) + + +@_register_kernel_internal(rotate, 
tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _rotate_keypoints_dispatch( + keypoints: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs +) -> tv_tensors.KeyPoints: + out, canvas_size = rotate_keypoints(keypoints, angle, center=center, expand=expand, **kwargs) + return tv_tensors.wrap(out, like=keypoints, canvas_size=canvas_size) + + def rotate_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1319,6 +1489,35 @@ def pad_mask( return output +def pad_keypoints( + keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant" +): + SUPPORTED_MODES = ["constant"] + if padding_mode not in SUPPORTED_MODES: + # TODO: add support of other padding modes + raise ValueError( + f"Padding mode '{padding_mode}' is not supported with KeyPoints" + f" (supported modes are {', '.join(SUPPORTED_MODES)})" + ) + left, right, top, bottom = _parse_pad_padding(padding) + pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device) + canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right) + return clamp_keypoints(keypoints + pad, canvas_size), canvas_size + + +@_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _pad_keypoints_dispatch( + keypoints: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs +) -> tv_tensors.KeyPoints: + output, canvas_size = pad_keypoints( + keypoints.as_subclass(torch.Tensor), + canvas_size=keypoints.canvas_size, + padding=padding, + padding_mode=padding_mode, + ) + return tv_tensors.wrap(output, like=keypoints, canvas_size=canvas_size) + + def pad_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1405,6 +1604,28 @@ def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int _register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil) +def crop_keypoints( + keypoints: torch.Tensor, + top: int, + left: int, + height: int, + width: int, +) -> tuple[torch.Tensor, tuple[int, int]]: + + keypoints.sub_(torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)) + canvas_size = (height, width) + + return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size + + +@_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _crop_keypoints_dispatch( + inpt: tv_tensors.KeyPoints, top: int, left: int, height: int, width: int +) -> tv_tensors.KeyPoints: + out, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width) + return tv_tensors.wrap(out, like=inpt, canvas_size=canvas_size) + + def crop_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1578,6 +1799,54 @@ def _perspective_image_pil( return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill) +def perspectice_keypoints( + keypoints: torch.Tensor, + canvas_size: tuple[int, int], + startpoints: Optional[list[list[int]]], + endpoints: Optional[list[list[int]]], + coefficients: Optional[list[float]] = None, +): + if keypoints.numel() == 0: + return keypoints + dtype = keypoints.dtype if torch.is_floating_point(keypoints) else torch.float32 + device = keypoints.device + + perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients) + + denom = perspective_coeffs[0] * perspective_coeffs[4] - perspective_coeffs[1] * perspective_coeffs[3] + 
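+    # denom is the determinant of the 2x2 linear part of the perspective matrix; when it is zero the
+    # coefficients cannot be inverted, so the keypoints cannot be transformed (same check as the bounding-box kernel).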
if denom == 0: + raise RuntimeError( + f"Provided perspective_coeffs {perspective_coeffs} can not be inverted to transform keypoints. " + f"Denominator is zero, denom={denom}" + ) + + theta1, theta2 = _compute_perspective_thetas(perspective_coeffs, dtype, device, denom) + keypoints = torch.cat([keypoints, torch.ones(keypoints.shape[0], 1, device=keypoints.device)], dim=-1) + + numer_points = torch.matmul(keypoints, theta1.T) + denom_points = torch.matmul(keypoints, theta2.T) + transformed_points = numer_points.div_(denom_points) + return clamp_keypoints(transformed_points, canvas_size) + + +@_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _perspective_keypoints_dispatch( + inpt: tv_tensors.BoundingBoxes, + startpoints: Optional[list[list[int]]], + endpoints: Optional[list[list[int]]], + coefficients: Optional[list[float]] = None, + **kwargs, +) -> tv_tensors.BoundingBoxes: + output = perspectice_keypoints( + inpt.as_subclass(torch.Tensor), + canvas_size=inpt.canvas_size, + startpoints=startpoints, + endpoints=endpoints, + coefficients=coefficients, + ) + return tv_tensors.wrap(output, like=inpt) + + def perspective_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -1619,26 +1888,7 @@ def perspective_bounding_boxes( f"Denominator is zero, denom={denom}" ) - inv_coeffs = [ - (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom, - (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom, - (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom, - (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom, - (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom, - (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom, - (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom, - (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, - ] - - theta1 = torch.tensor( - [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], - dtype=dtype, - device=device, - ) - - theta2 = torch.tensor( - [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device - ) + theta1, theta2 = _compute_perspective_thetas(perspective_coeffs, dtype, device, denom) # 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners). 
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes @@ -1672,6 +1922,36 @@ def perspective_bounding_boxes( ).reshape(original_shape) +def _compute_perspective_thetas( + perspective_coeffs: list[float], + dtype: torch.dtype, + device: torch.device, + denom: float, +) -> tuple[torch.Tensor, torch.Tensor]: + inv_coeffs = [ + (perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom, + (perspective_coeffs[1] * perspective_coeffs[5] - perspective_coeffs[2] * perspective_coeffs[4]) / denom, + (-perspective_coeffs[3] + perspective_coeffs[5] * perspective_coeffs[6]) / denom, + (perspective_coeffs[0] - perspective_coeffs[2] * perspective_coeffs[6]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[5] + perspective_coeffs[2] * perspective_coeffs[3]) / denom, + (-perspective_coeffs[4] * perspective_coeffs[6] + perspective_coeffs[3] * perspective_coeffs[7]) / denom, + (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom, + ] + + theta1 = torch.tensor( + [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]], + dtype=dtype, + device=device, + ) + + theta2 = torch.tensor( + [[inv_coeffs[6], inv_coeffs[7], 1.0], [inv_coeffs[6], inv_coeffs[7], 1.0]], dtype=dtype, device=device + ) + + return theta1, theta2 + + @_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) def _perspective_bounding_boxes_dispatch( inpt: tv_tensors.BoundingBoxes, @@ -1832,6 +2112,43 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: to return base_grid +def elastic_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor) -> torch.Tensor: + expected_shape = (1, canvas_size[0], canvas_size[1], 2) + if not isinstance(displacement, torch.Tensor): + raise TypeError("Argument displacement should be a Tensor") + elif displacement.shape != expected_shape: + raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}") + + if keypoints.numel() == 0: + return keypoints + + device = keypoints.device + dtype = keypoints.dtype if torch.is_floating_point(keypoints) else torch.float32 + + if displacement.dtype != dtype or displacement.device != device: + displacement = displacement.to(dtype=dtype, device=device) + + id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype) + inv_grid = id_grid.sub_(displacement) + + index_xy = keypoints.to(dtype=torch.long) + index_x, index_y = index_xy[:, 0], index_xy[:, 1] + # Unlike bounding boxes, this may not work well. 
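+    # Keypoints are truncated to integer pixel indices to look up the inverse displacement grid below,
+    # so sub-pixel positions are lost and points outside the canvas are first snapped onto the border.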
+ index_x.clamp_(0, inv_grid.shape[2] - 1) + index_y.clamp_(0, inv_grid.shape[1] - 1) + + t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype) + transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5) + + return clamp_keypoints(transformed_points, canvas_size=canvas_size) + + +@_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _elastic_keypoints_dispatch(inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs): + output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement) + return tv_tensors.wrap(output, like=inpt) + + def elastic_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -2012,6 +2329,20 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PI return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) +def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int]): + crop_height, crop_width = _center_crop_parse_output_size(output_size) + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size) + return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width) + + +@_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints: + output, canvas_size = center_crop_keypoints( + inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def center_crop_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, @@ -2147,6 +2478,28 @@ def _resized_crop_image_pil_dispatch( ) +def resized_crop_keypoints( + keypoints: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: list[int], +) -> tuple[torch.Tensor, tuple[int, int]]: + keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width) + return resize_keypoints(keypoints, size=size, canvas_size=canvas_size) + + +@_register_kernel_internal(resized_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def _resized_crop_keypoints_dispatch( + inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs +): + output, canvas_size = resized_crop_keypoints( + inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size + ) + return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size) + + def resized_crop_bounding_boxes( bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 31dae9a1a81..bd3cbd3c699 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -121,6 +121,11 @@ def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> list[int] return list(bounding_box.canvas_size) +@_register_kernel_internal(get_size, tv_tensors.KeyPoints, tv_tensor_wrapper=False) +def get_size_keypoints(keypoints: tv_tensors.KeyPoints) -> List[int]: + return list(keypoints.canvas_size) + + def get_num_frames(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_frames_video(inpt) @@ -176,6 +181,29 @@ def 
_xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: return xyxy +def _xyxy_to_keypoints(bounding_boxes: torch.Tensor) -> torch.Tensor: + return bounding_boxes[:, [[0, 1], [2, 1], [2, 3], [0, 3]]] + + +def convert_bounding_boxes_to_points(bounding_boxes: tv_tensors.BoundingBoxes) -> tv_tensors.KeyPoints: + """Converts a set of bounding boxes to its edge points. + + Args: + bounding_boxes (tv_tensors.BoundingBoxes): A set of ``N`` bounding boxes (of shape ``[N, 4]``) + + Returns: + tv_tensors.KeyPoints: The edges, of shape ``[N, 4, 2]`` + """ + # TODO: support rotated BBOX + bbox = _convert_bounding_box_format( + bounding_boxes.as_subclass(torch.Tensor), + old_format=bounding_boxes.format, + new_format=BoundingBoxFormat.XYXY, + inplace=False, + ) + return tv_tensors.KeyPoints(_xyxy_to_keypoints(bbox), canvas_size=bounding_boxes.canvas_size) + + def _cxcywhr_to_xywhr(cxcywhr: torch.Tensor, inplace: bool) -> torch.Tensor: if not inplace: cxcywhr = cxcywhr.clone() @@ -360,6 +388,16 @@ def _clamp_bounding_boxes( return out_boxes.to(in_dtype) +def clamp_keypoints(inpt: torch.Tensor, canvas_size: Tuple[int, int]) -> torch.Tensor: + if not torch.jit.is_scripting(): + _log_api_usage_once(clamp_bounding_boxes) + dtype = inpt.dtype + inpt = inpt.float() + inpt[..., 0].clamp_(0, canvas_size[1]) + inpt[..., 1].clamp_(0, canvas_size[0]) + return inpt.to(dtype=dtype) + + def clamp_bounding_boxes( inpt: torch.Tensor, format: Optional[BoundingBoxFormat] = None, diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 7e167d788e6..ccd750eba0f 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -320,6 +320,7 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: return to_dtype_image(video, dtype, scale=scale) +@_register_kernel_internal(to_dtype, tv_tensors.KeyPoints, tv_tensor_wrapper=False) @_register_kernel_internal(to_dtype, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False) @_register_kernel_internal(to_dtype, tv_tensors.Mask, tv_tensor_wrapper=False) def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor: @@ -327,6 +328,70 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo return inpt.to(dtype) +def sanitize_keypoints( + keypoints: torch.Tensor, canvas_size: Optional[Tuple[int, int]] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """Removes degenerate/invalid keypoints and returns the corresponding indexing mask. + + This removes the keypoints that are outside of their corresponing image. + You may want to first call :func:`~torchvision.transforms.v2.functional.clam_keypoints` + first to avoid undesired removals. + + .. note:: + Points that touch the edge of the canvas are removed, unlike for :func:`sanitize_bounding_boxes` + + Raises: + ValueError: If the keypoints are not passed as a two dimensional tensor. + + Args: + keypoints (torch.Tensor or class:`~torchvision.tv_tensors.KeyPoints`): The Keypoints being removed + canvas_size (Optional[Tuple[int, int]], optional): The canvas_size of the bounding boxes + (size of the corresponding image/video). + Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.KeyPoints` object. + + Returns: + out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask. + The mask can then be used to subset other tensors (e.g. 
labels) that are associated with the bounding boxes. + """ + if not keypoints.ndim == 2: + if keypoints.ndim < 2: + raise ValueError("Cannot sanitize a single Keypoint") + raise ValueError( + "Cannot sanitize KeyPoints structure that are not 2D. " + f"Expected shape to be (N, 2), got {keypoints.shape} ({keypoints.ndim=}, not 2)" + ) + if torch.jit.is_scripting() or is_pure_tensor(keypoints): + if canvas_size is None: + raise ValueError( + "canvas_size cannot be None if keypoints is a pure tensor. " + f"Got canvas_size={canvas_size}." + "Set that to appropriate values or pass keypoints as a tv_tensors.KeyPoints object." + ) + valid = _get_sanitize_keypoints_mask( + keypoints, + canvas_size=canvas_size, + ) + return keypoints[valid], valid + if not isinstance(keypoints, tv_tensors.KeyPoints): + raise ValueError("keypoints must be a tv_tensors.KeyPoints instance or a pure tensor.") + valid = _get_sanitize_keypoints_mask( + keypoints, + canvas_size=keypoints.canvas_size, + ) + return tv_tensors.wrap(keypoints[valid], like=keypoints), valid + + +def _get_sanitize_keypoints_mask( + keypoints: torch.Tensor, + canvas_size: Tuple[int, int], +) -> torch.Tensor: + image_h, image_w = canvas_size + x = keypoints[:, 0] + y = keypoints[:, 1] + + return (0 < x) & (x < image_w) & (0 < y) & (y < image_h) + + def sanitize_bounding_boxes( bounding_boxes: torch.Tensor, format: Optional[tv_tensors.BoundingBoxFormat] = None, diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py index 1ba47f60a36..e1c6b2202df 100644 --- a/torchvision/tv_tensors/__init__.py +++ b/torchvision/tv_tensors/__init__.py @@ -1,18 +1,24 @@ +from typing import TypeVar + import torch from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat from ._image import Image +from ._keypoints import KeyPoints from ._mask import Mask from ._torch_function_helpers import set_return_type from ._tv_tensor import TVTensor from ._video import Video +_WRAP_LIKE_T = TypeVar("_WRAP_LIKE_T", bound=TVTensor) + + # TODO: Fix this. We skip this method as it leads to # RecursionError: maximum recursion depth exceeded while calling a Python object # Until `disable` is removed, there will be graph breaks after all calls to functional transforms @torch.compiler.disable -def wrap(wrappee, *, like, **kwargs): +def wrap(wrappee: torch.Tensor, *, like: _WRAP_LIKE_T, **kwargs) -> _WRAP_LIKE_T: """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``. If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of @@ -26,10 +32,25 @@ def wrap(wrappee, *, like, **kwargs): Ignored otherwise. 
""" if isinstance(like, BoundingBoxes): - return BoundingBoxes._wrap( + return BoundingBoxes._wrap( # type:ignore wrappee, format=kwargs.get("format", like.format), canvas_size=kwargs.get("canvas_size", like.canvas_size), ) + elif isinstance(like, KeyPoints): + return KeyPoints(wrappee, canvas_size=kwargs.get("canvas_size", like.canvas_size)) # type:ignore else: return wrappee.as_subclass(type(like)) + + +__all__: list[str] = [ + "wrap", + "KeyPoints", + "Video", + "TVTensor", + "set_return_type", + "Mask", + "Image", + "BoundingBoxFormat", + "BoundingBoxes", +] diff --git a/torchvision/tv_tensors/_keypoints.py b/torchvision/tv_tensors/_keypoints.py new file mode 100644 index 00000000000..e00c58d5134 --- /dev/null +++ b/torchvision/tv_tensors/_keypoints.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import Any, Mapping, MutableSequence, Optional, Sequence, Tuple, TYPE_CHECKING, Union + +import torch +from torch.utils._pytree import tree_flatten + +from ._tv_tensor import TVTensor + + +class KeyPoints(TVTensor): + """:class:`torch.Tensor` subclass for tensors with shape ``[..., 2]`` that represent points in an image. + + Each point is represented by its XY coordinates. + + KeyPoints can be converted from :class:`torchvision.tv_tensors.BoundingBoxes` + by :func:`torchvision.transforms.v2.functional.convert_box_to_points`. + + Args: + data: Any data that can be turned into a tensor with :func:`torch.as_tensor`. + canvas_size (two-tuple of ints): Height and width of the corresponding image or video. + dtype (torch.dtype, optional): Desired data type of the bounding box. If omitted, will be inferred from + ``data``. + device (torch.device, optional): Desired device of the bounding box. If omitted and ``data`` is a + :class:`torch.Tensor`, the device is taken from it. Otherwise, the bounding box is constructed on the CPU. + requires_grad (bool, optional): Whether autograd should record operations on the bounding box. If omitted and + ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``. + """ + + canvas_size: Tuple[int, int] + + def __new__( + cls, + data: Any, + *, + canvas_size: Tuple[int, int], + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ): + tensor: torch.Tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + elif tensor.shape[-1] != 2: + raise ValueError(f"Expected a tensor of shape (..., 2), not {tensor.shape}") + points = tensor.as_subclass(cls) + points.canvas_size = canvas_size + return points + + if TYPE_CHECKING: + # EVIL: Just so that MYPY+PYLANCE+others stop shouting that everything is wrong when initializeing the TVTensor + # Not read or defined at Runtime (only at linting time). + # TODO: BOUNDING BOXES needs something similar + def __init__( + self, + data: Any, + *, + canvas_size: Tuple[int, int], + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str, int]] = None, + requires_grad: Optional[bool] = None, + ): + ... + + @classmethod + def _wrap_output( + cls, + output: Any, + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> Any: + # Mostly copied over from the BoundingBoxes TVTensor, minor improvements. + # This copies over the metadata. + # For BoundingBoxes, that included format, but we only support one format here ! 
+ flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator] + first_bbox_from_args = next(x for x in flat_params if isinstance(x, KeyPoints)) + canvas_size: Tuple[int, int] = first_bbox_from_args.canvas_size + + if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints): + output = KeyPoints(output, canvas_size=canvas_size) + elif isinstance(output, MutableSequence): + # For lists and list-like object we don't try to create a new object, we just set the values in the list + # This allows us to conserve the type of complex list-like object that may not follow the initialization API of lists + for i, part in enumerate(output): + output[i] = KeyPoints(part, canvas_size=canvas_size) + elif isinstance(output, Sequence): + # Non-mutable sequences handled here (like tuples) + # Every sequence that is not a mutable sequence is a non-mutable sequence + # We have to use a tuple here, since we know its initialization api, unlike for `output` + output = tuple(KeyPoints(part, canvas_size=canvas_size) for part in output) + return output + + def __repr__(self, *, tensor_contents: Any = None) -> str: + return self._make_repr(canvas_size=self.canvas_size)
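Taken together, these changes let ``KeyPoints`` flow through the v2 functional kernels and transform pipelines alongside images, boxes and masks. The sketch below is illustrative only — the values and the pipeline are made up, and it assumes the patch above is applied:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

image = tv_tensors.Image(torch.randint(0, 256, (3, 32, 64), dtype=torch.uint8))
points = tv_tensors.KeyPoints([[10, 5], [40, 20]], canvas_size=(32, 64))

# Functional kernels registered by this patch dispatch on KeyPoints directly.
flipped = F.horizontal_flip(points)
resized = F.resize(points, size=[64, 128])
print(resized.canvas_size)  # (64, 128)

# KeyPoints also ride along in a transform pipeline, like boxes and masks do.
pipeline = v2.Compose([v2.Resize((64, 128)), v2.RandomHorizontalFlip(p=1.0)])
out_image, out_points = pipeline(image, points)
```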