Skip to content
18 changes: 18 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def sample_position(values, max_value):
h, w = [torch.randint(1, s, (num_boxes,)) for s in canvas_size]
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])
r = -360 * torch.rand((num_boxes,)) + 180

if format is tv_tensors.BoundingBoxFormat.XYWH:
parts = (x, y, w, h)
Expand All @@ -435,6 +436,23 @@ def sample_position(values, max_value):
cx = x + w / 2
cy = y + h / 2
parts = (cx, cy, w, h)
elif format is tv_tensors.BoundingBoxFormat.XYWHR:
parts = (x, y, w, h, r)
elif format is tv_tensors.BoundingBoxFormat.CXCYWHR:
cx = x + w / 2
cy = y + h / 2
parts = (cx, cy, w, h, r)
elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
r_rad = r * torch.pi / 180.0
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
x1, y1 = x, y
x3 = x1 + w * cos
y3 = y1 - w * sin
x2 = x3 + h * sin
y2 = y3 + h * cos
x4 = x1 + h * sin
y4 = y1 + h * cos
parts = (x1, y1, x3, y3, x2, y2, x4, y4)
else:
raise ValueError(f"Format {format} is not supported")

Expand Down
57 changes: 55 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,8 +1339,61 @@ def test_bbox_xywh_cxcywh(self):
box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
assert_equal(box_xywh, box_tensor)

@pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
@pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
def test_bbox_xywhr_cxcywhr(self):
    """Check ``xywhr`` -> ``cxcywhr`` conversion and the reverse round trip.

    Each row holds five values. The first five rows use a rotation of 0;
    the last three use rotations of 90, 180 and -90 and are all expected
    to map to the same ``cxcywhr`` centre and size, (6, 3, 4, 2).
    """
    xywhr_rows = [
        [0, 0, 100, 100, 0],
        [0, 0, 0, 0, 0],
        [10, 15, 20, 20, 0],
        [23, 35, 70, 60, 0],
        [4, 2, 4, 2, 0],
        [5, 5, 4, 2, 90],
        [8, 4, 4, 2, 180],
        [7, 1, 4, 2, -90],
    ]
    cxcywhr_rows = [
        [50, 50, 100, 100, 0],
        [0, 0, 0, 0, 0],
        [20, 25, 20, 20, 0],
        [58, 65, 70, 60, 0],
        [6, 3, 4, 2, 0],
        [6, 3, 4, 2, 90],
        [6, 3, 4, 2, 180],
        [6, 3, 4, 2, -90],
    ]
    boxes_xywhr = torch.tensor(xywhr_rows, dtype=torch.float)
    expected_cxcywhr = torch.tensor(cxcywhr_rows, dtype=torch.float)

    # Sanity check on the fixture itself: eight boxes, five values each.
    assert expected_cxcywhr.size() == torch.Size([8, 5])

    # Forward conversion
    converted = ops.box_convert(boxes_xywhr, in_fmt="xywhr", out_fmt="cxcywhr")
    torch.testing.assert_close(converted, expected_cxcywhr)

    # Reverse conversion must recover the original boxes.
    round_tripped = ops.box_convert(converted, in_fmt="cxcywhr", out_fmt="xywhr")
    torch.testing.assert_close(round_tripped, boxes_xywhr)

def test_bbox_cxcywhr_to_xyxyxyxy(self):
    """Check conversion of one ``cxcywhr`` box into the 8-value ``xyxyxyxy`` format."""
    expected_corners = torch.tensor([[4, 5, 4, 1, 6, 1, 6, 5]], dtype=torch.float)
    # Sanity check on the fixture: one box, eight values.
    assert expected_corners.size() == torch.Size([1, 8])

    rotated_box = torch.tensor([[5, 3, 4, 2, 90]], dtype=torch.float)
    converted = ops.box_convert(rotated_box, in_fmt="cxcywhr", out_fmt="xyxyxyxy")
    torch.testing.assert_close(converted, expected_corners)

def test_bbox_xywhr_to_xyxyxyxy(self):
    """Check conversion of one ``xywhr`` box into the 8-value ``xyxyxyxy`` format."""
    expected_corners = torch.tensor([[4, 5, 4, 1, 6, 1, 6, 5]], dtype=torch.float)
    # Sanity check on the fixture: one box, eight values.
    assert expected_corners.size() == torch.Size([1, 8])

    rotated_box = torch.tensor([[4, 5, 4, 2, 90]], dtype=torch.float)
    converted = ops.box_convert(rotated_box, in_fmt="xywhr", out_fmt="xyxyxyxy")
    torch.testing.assert_close(converted, expected_corners)

@pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh", "xwyhr", "cxwyhr", "xxxxyyyy"])
@pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy", "xwcxr", "xhwcyr", "xyxyxxyy"])
def test_bbox_invalid(self, inv_infmt, inv_outfmt):
box_tensor = torch.tensor(
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
Expand Down
76 changes: 45 additions & 31 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal


# The transform functions are still being adjusted for the rotated and
# oriented bounding box formats. Until that work is done, tests are restricted
# to the formats whose transform functions are already implemented.
# In the future, this global variable will be replaced with
# `list(tv_tensors.BoundingBoxFormat)` to cover all available formats.
SUPPORTED_BOX_FORMATS = [
    tv_tensors.BoundingBoxFormat.XYXY,
    tv_tensors.BoundingBoxFormat.XYWH,
    tv_tensors.BoundingBoxFormat.CXCYWH,
]
# Newly added formats, kept in a separate list from the supported ones.
NEW_BOX_FORMATS = [
    tv_tensors.BoundingBoxFormat.XYWHR,
    tv_tensors.BoundingBoxFormat.CXCYWHR,
    tv_tensors.BoundingBoxFormat.XYXYXYXY,
]

# turns all warnings into errors for this module
pytestmark = [pytest.mark.filterwarnings("error")]

Expand Down Expand Up @@ -626,7 +635,7 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype,
check_scripted_vs_eager=not isinstance(size, int),
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize("use_max_size", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
Expand Down Expand Up @@ -757,7 +766,7 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
new_canvas_size=(new_height, new_width),
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize("use_max_size", [True, False])
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
Expand Down Expand Up @@ -1003,7 +1012,7 @@ class TestHorizontalFlip:
def test_kernel_image(self, dtype, device):
check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
Expand Down Expand Up @@ -1072,7 +1081,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):

return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
Expand Down Expand Up @@ -1169,7 +1178,7 @@ def test_kernel_image(self, param, value, dtype, device):
shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
Expand Down Expand Up @@ -1318,7 +1327,7 @@ def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate,
),
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
@pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
Expand Down Expand Up @@ -1346,7 +1355,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, translate, s

torch.testing.assert_close(actual, expected)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_bounding_boxes_correctness(self, format, center, seed):
Expand Down Expand Up @@ -1453,7 +1462,7 @@ class TestVerticalFlip:
def test_kernel_image(self, dtype, device):
check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
Expand Down Expand Up @@ -1520,7 +1529,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):

return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_bounding_boxes_correctness(self, format, fn):
bounding_boxes = make_bounding_boxes(format=format)
Expand Down Expand Up @@ -1589,7 +1598,7 @@ def test_kernel_image(self, param, value, dtype, device):
expand=[False, True],
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
Expand Down Expand Up @@ -1760,7 +1769,7 @@ def _reference_rotate_bounding_boxes(self, bounding_boxes, *, angle, expand, cen
bounding_boxes
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
Expand All @@ -1773,7 +1782,7 @@ def test_functional_bounding_boxes_correctness(self, format, angle, expand, cent
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("seed", list(range(5)))
Expand Down Expand Up @@ -2694,7 +2703,7 @@ def test_kernel_image(self, param, value, dtype, device):
check_cuda_vs_cpu=dtype is not torch.float16,
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
Expand Down Expand Up @@ -2821,7 +2830,7 @@ def test_kernel_image(self, kwargs, dtype, device):
check_kernel(F.crop_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **kwargs)

@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box(self, kwargs, format, dtype, device):
Expand Down Expand Up @@ -2971,7 +2980,7 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w
)

@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
Expand All @@ -2984,7 +2993,7 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device
assert_equal(F.get_size(actual), F.get_size(expected))

@pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)])
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("seed", list(range(5)))
Expand Down Expand Up @@ -3507,7 +3516,8 @@ def test_aug_mix_severity_error(self, severity):


class TestConvertBoundingBoxFormat:
old_new_formats = list(itertools.permutations(iter(tv_tensors.BoundingBoxFormat), 2))
old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2))
old_new_formats += list(itertools.permutations(NEW_BOX_FORMATS, 2))
Comment on lines +3519 to +3520
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that eventually, we'll want this to become old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS + NEW_BOX_FORMATS, 2))

i.e. we'll probably also want to enable the kind of rotated format <--> non-rotated format conversions. But this can come later and for this PR, we can keep things as-is.


@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_kernel(self, old_format, new_format):
Expand All @@ -3518,7 +3528,7 @@ def test_kernel(self, old_format, new_format):
old_format=old_format,
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("inplace", [False, True])
def test_kernel_noop(self, format, inplace):
input = make_bounding_boxes(format=format).as_subclass(torch.Tensor)
Expand All @@ -3542,9 +3552,13 @@ def test_kernel_inplace(self, old_format, new_format):
output_inplace = F.convert_bounding_box_format(
input, old_format=old_format, new_format=new_format, inplace=True
)
assert output_inplace.data_ptr() == input.data_ptr()
assert output_inplace._version > input_version
assert output_inplace is input
if old_format != tv_tensors.BoundingBoxFormat.XYXYXYXY and new_format != tv_tensors.BoundingBoxFormat.XYXYXYXY:
# NOTE: BoundingBox format conversion from and to XYXYXYXY format
# cannot modify the input tensor inplace as it requires a dimension
# change.
assert output_inplace.data_ptr() == input.data_ptr()
assert output_inplace._version > input_version
assert output_inplace is input

assert_equal(output_inplace, output_out_of_place)

Expand All @@ -3563,7 +3577,7 @@ def test_transform(self, old_format, new_format, format_type):
@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
def test_strings(self, old_format, new_format):
# Non-regression test for https://github.com/pytorch/vision/issues/8258
input = tv_tensors.BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format=old_format, canvas_size=(50, 50))
input = make_bounding_boxes(format=old_format, canvas_size=(50, 50))
expected = self._reference_convert_bounding_box_format(input, new_format)

old_format = old_format.name
Expand Down Expand Up @@ -3728,7 +3742,7 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
new_canvas_size=size,
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
def test_functional_bounding_boxes_correctness(self, format):
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)

Expand Down Expand Up @@ -3796,7 +3810,7 @@ def test_kernel_image(self, param, value, dtype, device):
),
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
def test_kernel_bounding_boxes(self, format):
bounding_boxes = make_bounding_boxes(format=format)
check_kernel(
Expand Down Expand Up @@ -3915,7 +3929,7 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding):
)

@pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
Expand Down Expand Up @@ -3944,7 +3958,7 @@ def test_kernel_image(self, output_size, dtype, device):
)

@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
def test_kernel_bounding_boxes(self, output_size, format):
bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
check_kernel(
Expand Down Expand Up @@ -4023,7 +4037,7 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
)

@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
Expand Down Expand Up @@ -4090,7 +4104,7 @@ def test_kernel_image_error(self):
coefficients=COEFFICIENTS,
start_end_points=START_END_POINTS,
)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
def test_kernel_bounding_boxes(self, param, value, format):
if param == "start_end_points":
kwargs = dict(zip(["startpoints", "endpoints"], value))
Expand Down Expand Up @@ -4266,7 +4280,7 @@ def perspective_bounding_boxes(bounding_boxes):
)

@pytest.mark.parametrize(("startpoints", "endpoints"), START_END_POINTS)
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, format, dtype, device):
Expand Down Expand Up @@ -4473,7 +4487,7 @@ def test_correctness_image(self, mean, std, dtype, fn):


class TestClampBoundingBoxes:
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
@pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel(self, format, dtype, device):
Expand All @@ -4485,7 +4499,7 @@ def test_kernel(self, format, dtype, device):
canvas_size=bounding_boxes.canvas_size,
)

@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
def test_functional(self, format):
check_functional(F.clamp_bounding_boxes, make_bounding_boxes(format=format))

Expand Down
Loading
Loading