Skip to content

Batched Box Ops #9058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,15 +1073,15 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None):
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)

torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}"
)

# no modulation test
res = layer(x, offset)
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)

torch.testing.assert_close(
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres: \n{res}\nexpected: \n{expected}"
)

def test_wrong_sizes(self):
Expand Down Expand Up @@ -1468,7 +1468,7 @@ def test_box_area_jit(self):
]


def gen_box(size, dtype=torch.float):
def gen_box(size, dtype=torch.float) -> Tensor:
xy1 = torch.rand((size, 2), dtype=dtype)
xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
return torch.cat([xy1, xy2], axis=-1)
Expand Down Expand Up @@ -1510,6 +1510,14 @@ def _run_cartesian_test(target_fn: Callable):
b = target_fn(boxes1, boxes2)
torch.testing.assert_close(a, b)

@staticmethod
def _run_batch_test(target_fn: Callable):
boxes1 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
boxes2 = torch.stack([gen_box(5) for _ in range(3)], dim=0)
native: Tensor = target_fn(boxes1, boxes2)
iterative: Tensor = torch.stack([target_fn(*pairs) for pairs in zip(boxes1, boxes2)], dim=0)
torch.testing.assert_close(native, iterative)


class TestBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
Expand All @@ -1532,6 +1540,9 @@ def test_iou_jit(self):
def test_iou_cartesian(self):
self._run_cartesian_test(ops.box_iou)

def test_iou_batch(self):
self._run_batch_test(ops.box_iou)


class TestGeneralizedBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
Expand All @@ -1554,6 +1565,9 @@ def test_iou_jit(self):
def test_iou_cartesian(self):
self._run_cartesian_test(ops.generalized_box_iou)

def test_iou_batch(self):
self._run_batch_test(ops.generalized_box_iou)


class TestDistanceBoxIoU(TestIouBase):
int_expected = [
Expand Down Expand Up @@ -1581,6 +1595,9 @@ def test_iou_jit(self):
def test_iou_cartesian(self):
self._run_cartesian_test(ops.distance_box_iou)

def test_iou_batch(self):
self._run_batch_test(ops.distance_box_iou)


class TestCompleteBoxIou(TestIouBase):
int_expected = [
Expand Down Expand Up @@ -1608,6 +1625,9 @@ def test_iou_jit(self):
def test_iou_cartesian(self):
self._run_cartesian_test(ops.complete_box_iou)

def test_iou_batch(self):
self._run_batch_test(ops.complete_box_iou)


def get_boxes(dtype, device):
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
Expand Down
77 changes: 40 additions & 37 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.

Args:
boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
min_size (float): minimum size

Expand All @@ -140,7 +140,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(remove_small_boxes)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
ws, hs = boxes[..., 2] - boxes[..., 0], boxes[..., 3] - boxes[..., 1]
keep = (ws >= min_size) & (hs >= min_size)
keep = torch.where(keep)[0]
return keep
Expand All @@ -155,12 +155,12 @@ def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.

Args:
boxes (Tensor[N, 4]): boxes in ``(x1, y1, x2, y2)`` format
boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
size (Tuple[height, width]): size of the image

Returns:
Tensor[N, 4]: clipped boxes
Tensor[..., 4]: clipped boxes
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(clip_boxes_to_image)
Expand Down Expand Up @@ -276,7 +276,7 @@ def box_area(boxes: Tensor) -> Tensor:
(x1, y1, x2, y2) coordinates.

Args:
boxes (Tensor[N, 4]): boxes for which the area will be computed. They
boxes (Tensor[..., 4]): boxes for which the area will be computed. They
are expected to be in (x1, y1, x2, y2) format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.

Expand All @@ -286,7 +286,7 @@ def box_area(boxes: Tensor) -> Tensor:
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(box_area)
boxes = _upcast(boxes)
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
Expand All @@ -295,13 +295,13 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
area1 = box_area(boxes1)
area2 = box_area(boxes2)

lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]

wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
inter = wh[..., 0] * wh[..., 1] # [N,M]

union = area1[:, None] + area2 - inter
union = area1[..., None] + area2[..., None, :] - inter

return inter, union

Expand All @@ -314,11 +314,12 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
``0 <= x1 < x2`` and ``0 <= y1 < y2``.

Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
boxes1 (Tensor[..., N, 4]): first set of boxes
boxes2 (Tensor[..., M, 4]): second set of boxes

Returns:
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
in boxes1 and boxes2
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(box_iou)
Expand All @@ -336,11 +337,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
``0 <= x1 < x2`` and ``0 <= y1 < y2``.

Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
boxes1 (Tensor[..., N, 4]): first set of boxes
boxes2 (Tensor[..., M, 4]): second set of boxes

Returns:
Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values
Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values
for every element in boxes1 and boxes2
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Expand All @@ -349,11 +350,11 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union

lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])

whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
areai = whi[:, :, 0] * whi[:, :, 1]
areai = whi[..., 0] * whi[..., 1]

return iou - (areai - union) / areai

Expand All @@ -364,11 +365,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
boxes1 (Tensor[..., N, 4]): first set of boxes
boxes2 (Tensor[..., M, 4]): second set of boxes
eps (float, optional): small number to prevent division by zero. Default: 1e-7
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise complete IoU values
Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values
for every element in boxes1 and boxes2
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Expand All @@ -379,11 +380,11 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso

diou, iou = _box_diou_iou(boxes1, boxes2, eps)

w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]
w_pred = boxes1[..., None, 2] - boxes1[..., None, 0]
h_pred = boxes1[..., None, 3] - boxes1[..., None, 1]

w_gt = boxes2[:, 2] - boxes2[:, 0]
h_gt = boxes2[:, 3] - boxes2[:, 1]
w_gt = boxes2[..., None, :, 2] - boxes2[..., None, :, 0]
h_gt = boxes2[..., None, :, 3] - boxes2[..., None, :, 1]

v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
with torch.no_grad():
Expand All @@ -399,12 +400,12 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
``0 <= x1 < x2`` and ``0 <= y1 < y2``.

Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
boxes1 (Tensor[..., N, 4]): first set of boxes
boxes2 (Tensor[..., M, 4]): second set of boxes
eps (float, optional): small number to prevent division by zero. Default: 1e-7

Returns:
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values
for every element in boxes1 and boxes2
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Expand All @@ -419,17 +420,19 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> tuple[Tensor, Tensor]:

iou = box_iou(boxes1, boxes2)
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
diagonal_distance_squared = (whi[..., 0] ** 2) + (whi[..., 1] ** 2) + eps
# centers of boxes
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
x_p = (boxes1[..., 0] + boxes1[..., 2]) / 2
y_p = (boxes1[..., 1] + boxes1[..., 3]) / 2
x_g = (boxes2[..., 0] + boxes2[..., 2]) / 2
y_g = (boxes2[..., 1] + boxes2[..., 3]) / 2
# The distance between boxes' centers squared.
centers_distance_squared = (_upcast(x_p[:, None] - x_g[None, :]) ** 2) + (_upcast(y_p[:, None] - y_g[None, :]) ** 2)
centers_distance_squared = (_upcast(x_p[..., None] - x_g[..., None, :]) ** 2) + (
_upcast(y_p[..., None] - y_g[..., None, :]) ** 2
)
# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared), iou
Expand Down