From 0eba325ed2de25cf283029ad6fcc34d286980715 Mon Sep 17 00:00:00 2001 From: Bryce Ferenczi Date: Sun, 4 May 2025 12:18:13 +1000 Subject: [PATCH] Make many box ops batch-dim compatible. Add test for batched calculations. Signed-off-by: Bryce Ferenczi --- test/test_ops.py | 26 ++++++++++++-- torchvision/ops/boxes.py | 77 +++++++++++++++++++++------------------- 2 files changed, 63 insertions(+), 40 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3f0d8312c01..2e039282aa9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1073,7 +1073,7 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation) torch.testing.assert_close( - res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}" + res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}" ) # no modulation test @@ -1081,7 +1081,7 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation) torch.testing.assert_close( - res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}" + res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}" ) def test_wrong_sizes(self): @@ -1468,7 +1468,7 @@ def test_box_area_jit(self): ] -def gen_box(size, dtype=torch.float): +def gen_box(size, dtype=torch.float) -> Tensor: xy1 = torch.rand((size, 2), dtype=dtype) xy2 = xy1 + torch.rand((size, 2), dtype=dtype) return torch.cat([xy1, xy2], axis=-1) @@ -1510,6 +1510,14 @@ def _run_cartesian_test(target_fn: Callable): b = target_fn(boxes1, boxes2) torch.testing.assert_close(a, b) + @staticmethod + def _run_batch_test(target_fn: Callable): + boxes1 = torch.stack([gen_box(5) for _ in range(3)], dim=0) + boxes2 = torch.stack([gen_box(5) for _ in range(3)], dim=0) + native: Tensor = target_fn(boxes1, boxes2) + iterative: Tensor = torch.stack([target_fn(*pairs) for pairs in zip(boxes1, boxes2)], dim=0) + torch.testing.assert_close(native, iterative) + class TestBoxIou(TestIouBase): int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]] @@ -1532,6 +1540,9 @@ def test_iou_jit(self): def test_iou_cartesian(self): self._run_cartesian_test(ops.box_iou) + def test_iou_batch(self): + self._run_batch_test(ops.box_iou) + class TestGeneralizedBoxIou(TestIouBase): int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]] @@ -1554,6 +1565,9 @@ def test_iou_jit(self): def test_iou_cartesian(self): self._run_cartesian_test(ops.generalized_box_iou) + def test_iou_batch(self): + self._run_batch_test(ops.generalized_box_iou) + class TestDistanceBoxIoU(TestIouBase): int_expected = [ @@ -1581,6 +1595,9 @@ def test_iou_jit(self): def test_iou_cartesian(self): self._run_cartesian_test(ops.distance_box_iou) + def test_iou_batch(self): + self._run_batch_test(ops.distance_box_iou) + class TestCompleteBoxIou(TestIouBase): int_expected = [ @@ -1608,6 +1625,9 @@ def test_iou_jit(self): def test_iou_cartesian(self): self._run_cartesian_test(ops.complete_box_iou) + def test_iou_batch(self): + self._run_batch_test(ops.complete_box_iou) + def get_boxes(dtype, device): box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index be8d59716bf..6c633239e40 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -130,7 +130,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor: the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead. Args: - boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format + boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. min_size (float): minimum size @@ -140,7 +140,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor: """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(remove_small_boxes) - ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] + ws, hs = boxes[..., 2] - boxes[..., 0], boxes[..., 3] - boxes[..., 1] keep = (ws >= min_size) & (hs >= min_size) keep = torch.where(keep)[0] return keep @@ -155,12 +155,12 @@ def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor: the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead. Args: - boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format + boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. size (Tuple[height, width]): size of the image Returns: - Tensor[N, 4]: clipped boxes + Tensor[..., 4]: clipped boxes """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(clip_boxes_to_image) @@ -276,7 +276,7 @@ def box_area(boxes: Tensor) -> Tensor: (x1, y1, x2, y2) coordinates. Args: - boxes (Tensor[N, 4]): boxes for which the area will be computed. They + boxes (Tensor[..., 4]): boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. @@ -286,7 +286,7 @@ def box_area(boxes: Tensor) -> Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(box_area) boxes = _upcast(boxes) - return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1]) # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py @@ -295,13 +295,13 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: area1 = box_area(boxes1) area2 = box_area(boxes2) - lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] + rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2] wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] - inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + inter = wh[..., 0] * wh[..., 1] # [N,M] - union = area1[:, None] + area2 - inter + union = area1[..., None] + area2[..., None, :] - inter return inter, union @@ -314,11 +314,12 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: ``0 <= x1 < x2`` and ``0 <= y1 < y2``. Args: - boxes1 (Tensor[N, 4]): first set of boxes - boxes2 (Tensor[M, 4]): second set of boxes + boxes1 (Tensor[..., N, 4]): first set of boxes + boxes2 (Tensor[..., M, 4]): second set of boxes Returns: - Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element + in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(box_iou) @@ -336,11 +337,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: ``0 <= x1 < x2`` and ``0 <= y1 < y2``. Args: - boxes1 (Tensor[N, 4]): first set of boxes - boxes2 (Tensor[M, 4]): second set of boxes + boxes1 (Tensor[..., N, 4]): first set of boxes + boxes2 (Tensor[..., M, 4]): second set of boxes Returns: - Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values + Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values for every element in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -349,11 +350,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: inter, union = _box_inter_union(boxes1, boxes2) iou = inter / union - lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2]) + rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] - areai = whi[:, :, 0] * whi[:, :, 1] + areai = whi[..., 0] * whi[..., 1] return iou - (areai - union) / areai @@ -364,11 +365,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. Args: - boxes1 (Tensor[N, 4]): first set of boxes - boxes2 (Tensor[M, 4]): second set of boxes + boxes1 (Tensor[..., N, 4]): first set of boxes + boxes2 (Tensor[..., M, 4]): second set of boxes eps (float, optional): small number to prevent division by zero. Default: 1e-7 Returns: - Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values + Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values for every element in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -379,11 +380,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso diou, iou = _box_diou_iou(boxes1, boxes2, eps) - w_pred = boxes1[:, None, 2] - boxes1[:, None, 0] - h_pred = boxes1[:, None, 3] - boxes1[:, None, 1] + w_pred = boxes1[..., None, 2] - boxes1[..., None, 0] + h_pred = boxes1[..., None, 3] - boxes1[..., None, 1] - w_gt = boxes2[:, 2] - boxes2[:, 0] - h_gt = boxes2[:, 3] - boxes2[:, 1] + w_gt = boxes2[..., None, :, 2] - boxes2[..., None, :, 0] + h_gt = boxes2[..., None, :, 3] - boxes2[..., None, :, 1] v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2) with torch.no_grad(): @@ -399,12 +400,12 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso ``0 <= x1 < x2`` and ``0 <= y1 < y2``. Args: - boxes1 (Tensor[N, 4]): first set of boxes - boxes2 (Tensor[M, 4]): second set of boxes + boxes1 (Tensor[..., N, 4]): first set of boxes + boxes2 (Tensor[..., M, 4]): second set of boxes eps (float, optional): small number to prevent division by zero. Default: 1e-7 Returns: - Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values + Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values for every element in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -419,17 +420,19 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> tuple[Tensor, Tensor]: iou = box_iou(boxes1, boxes2) - lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2]) + rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] - diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps + diagonal_distance_squared = (whi[..., 0] ** 2) + (whi[..., 1] ** 2) + eps # centers of boxes - x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2 - y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2 - x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2 - y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 + x_p = (boxes1[..., 0] + boxes1[..., 2]) / 2 + y_p = (boxes1[..., 1] + boxes1[..., 3]) / 2 + x_g = (boxes2[..., 0] + boxes2[..., 2]) / 2 + y_g = (boxes2[..., 1] + boxes2[..., 3]) / 2 # The distance between boxes' centers squared. - centers_distance_squared = (_upcast(x_p[:, None] - x_g[None, :]) ** 2) + (_upcast(y_p[:, None] - y_g[None, :]) ** 2) + centers_distance_squared = (_upcast(x_p[..., None] - x_g[..., None, :]) ** 2) + ( + _upcast(y_p[..., None] - y_g[..., None, :]) ** 2 + ) # The distance IoU is the IoU penalized by a normalized # distance between boxes' centers squared. return iou - (centers_distance_squared / diagonal_distance_squared), iou