This repository was archived by the owner on Mar 12, 2024. It is now read-only.

Commit eff320f

alcinosdai20242024 authored and committed
Reduce HungarianMatcher's space complexity.
The memory reduction factor of the cost matrix is sum(#target objects) / max(#target objects). That is achieved by no longer computing and storing matching costs between predictions and targets at different positions inside the batch. More precisely, the original matrix of shape [batch_size * queries, sum(#target objects)] is shrunk to a tensor of shape [batch_size, queries, max(#target objects)].

Besides allowing much larger batch sizes, this change also yields a) a small but meaningful speedup on CUDA at all batch sizes and on CPU at small batch sizes, and b) much higher speedups on CPU at larger batch sizes, as tested on the table structure recognition task using the Table Transformer (TATR) (125 queries, 7 classes) with PubMed data. The processing time decrease, computed as (1 - new_time / old_time), is shown below for various configurations:

Batch |     Device
 size | cuda     cpu
---------------------
    1 | 8.2%    1.6%
    2 | 1.6%    9.3%
    3 | 1.6%    7.7%
    4 | 0.9%   11.2%
    5 | 0.8%   13.9%
    6 | 0.9%   15.5%
    7 | 0.9%   23.1%
    8 |        47.1%
   16 |        70.6%
   32 |        88.3%
   64 |        95.0%
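As a quick illustration of the stated reduction factor (an editorial sketch; the per-image object counts below are hypothetical, not taken from the benchmark above):

# Compare the number of cost-matrix cells before and after the change.
# The per-image object counts and the query count are made-up example values.
num_queries = 125                      # TATR-style query count
objects_per_image = [12, 30, 7, 21]    # hypothetical #target objects per batch element
batch_size = len(objects_per_image)

old_cells = (batch_size * num_queries) * sum(objects_per_image)  # [batch_size * queries, sum(#objects)]
new_cells = batch_size * num_queries * max(objects_per_image)    # [batch_size, queries, max(#objects)]

# old_cells / new_cells == sum(objects_per_image) / max(objects_per_image) == 70 / 30, about 2.33 here.
print(old_cells, new_cells, old_cells / new_cells)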
1 parent 3af9fa8 commit eff320f

3 files changed, +138 -42 lines changed

models/matcher.py

Lines changed: 55 additions & 22 deletions
@@ -5,6 +5,7 @@
 import torch
 from scipy.optimize import linear_sum_assignment
 from torch import nn
+from torch.nn.utils.rnn import pad_sequence

 from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

@@ -52,34 +53,66 @@ def forward(self, outputs, targets):
             For each batch element, it holds:
                 len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
         """
-        bs, num_queries = outputs["pred_logits"].shape[:2]
-
-        # We flatten to compute the cost matrices in a batch
-        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
-        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
-
-        # Also concat the target labels and boxes
-        tgt_ids = torch.cat([v["labels"] for v in targets])
-        tgt_bbox = torch.cat([v["boxes"] for v in targets])
-
-        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
-        # but approximate it in 1 - proba[target class].
-        # The 1 is a constant that doesn't change the matching, it can be ommitted.
-        cost_class = -out_prob[:, tgt_ids]
+        # In the comments below:
+        # - `bs` is the batch size, i.e. outputs["pred_logits"].shape[0];
+        # - `mo` is the maximum number of objects over all the targets,
+        #   i.e. `max((len(v["labels"]) for v in targets))`;
+        # - `q` is the number of queries, i.e. outputs["pred_logits"].shape[1];
+        # - `cl` is the number of classes including no-object,
+        #   i.e. outputs["pred_logits"].shape[2] or self.num_classes + 1.
+        if len(targets) == 1:
+            # This branch is just an optimization, not needed for correctness.
+            tgt_ids = targets[0]["labels"].unsqueeze(dim=0)
+            tgt_bbox = targets[0]["boxes"].unsqueeze(dim=0)
+        else:
+            tgt_ids = pad_sequence(
+                [target["labels"] for target in targets],
+                batch_first=True,
+                padding_value=0
+            )  # (bs, mo)
+            tgt_bbox = pad_sequence(
+                [target["boxes"] for target in targets],
+                batch_first=True,
+                padding_value=0
+            )  # (bs, mo, 4)
+
+        out_bbox = outputs["pred_boxes"]  # (bs, q, 4)

         # Compute the L1 cost between boxes
-        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)  # (bs, q, mo)

         # Compute the giou cost betwen boxes
-        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+        out_bbox_xyxy = box_cxcywh_to_xyxy(out_bbox)
+        tgt_bbox_xyxy = box_cxcywh_to_xyxy(tgt_bbox)
+        giou = generalized_box_iou(
+            out_bbox_xyxy, tgt_bbox_xyxy)  # (bs, q, mo)
+
+        # Compute the classification cost. Contrary to the loss, we don't use
+        # the Negative Log Likelihood, but approximate it
+        # in `1 - proba[target class]`. The 1 is a constant that does not
+        # change the matching, it can be ommitted.
+        out_prob = outputs["pred_logits"].softmax(-1)  # (bs, q, c)
+        prob_class = torch.gather(
+            out_prob,
+            dim=2,
+            index=tgt_ids.unsqueeze(dim=1).expand(-1, out_prob.shape[1], -1)
+        )  # (bs, q, mo)

         # Final cost matrix
-        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
-        C = C.view(bs, num_queries, -1).cpu()
-
-        sizes = [len(v["boxes"]) for v in targets]
-        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
-        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+        C = self.cost_bbox * cost_bbox - self.cost_giou * giou - self.cost_class * prob_class
+        c = C.cpu()
+
+        indices = [
+            linear_sum_assignment(c[i, :, :len(v["labels"])])
+            for i, v in enumerate(targets)
+        ]
+        return [
+            (
+                torch.as_tensor(i, dtype=torch.int64),
+                torch.as_tensor(j, dtype=torch.int64),
+            )
+            for i, j in indices
+        ]


 def build_matcher(args):
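
An editorial note (not part of the commit): the padded torch.gather formulation above reproduces the original per-image indexing -out_prob[:, tgt_ids] once the padded columns are sliced off, which is why restricting linear_sum_assignment to the first len(v["labels"]) columns is enough. A minimal self-contained sketch with made-up shapes and labels:

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical sizes, chosen only for this sketch: 2 images, 4 queries, 5 classes.
bs, q, cl = 2, 4, 5
out_prob = torch.rand(bs, q, cl).softmax(-1)            # (bs, q, cl)
labels = [torch.tensor([1, 3, 0]), torch.tensor([2])]   # ragged per-image target labels

tgt_ids = pad_sequence(labels, batch_first=True, padding_value=0)  # (bs, mo)
prob_class = torch.gather(
    out_prob,
    dim=2,
    index=tgt_ids.unsqueeze(dim=1).expand(-1, q, -1),
)  # (bs, q, mo)

# Per image, the first len(labels[i]) columns match the old flat indexing;
# the trailing padded columns are ignored by the sliced assignment call.
for i, ids in enumerate(labels):
    assert torch.equal(prob_class[i, :, :len(ids)], out_prob[i][:, ids])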

test_all.py

Lines changed: 69 additions & 6 deletions
@@ -1,9 +1,14 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import io
 import unittest
+import functools
+import operator
+
+from itertools import combinations_with_replacement

 import torch
 from torch import nn, Tensor
+from torchvision import ops
 from typing import List

 from models.matcher import HungarianMatcher
@@ -40,14 +45,21 @@ def test_hungarian(self):
         matcher = HungarianMatcher()
         targets = [{'labels': tgt_labels, 'boxes': tgt_boxes}]
         indices_single = matcher({'pred_logits': logits, 'pred_boxes': boxes}, targets)
-        indices_batched = matcher({'pred_logits': logits.repeat(2, 1, 1),
-                                   'pred_boxes': boxes.repeat(2, 1, 1)}, targets * 2)
+        batch_size = 2
+        indices_batched = matcher(
+            {
+                'pred_logits': logits.repeat(batch_size, 1, 1),
+                'pred_boxes': boxes.repeat(batch_size, 1, 1),
+            },
+            targets * batch_size,
+        )
         self.assertEqual(len(indices_single[0][0]), n_targets)
         self.assertEqual(len(indices_single[0][1]), n_targets)
-        self.assertEqual(self.indices_torch2python(indices_single),
-                         self.indices_torch2python([indices_batched[0]]))
-        self.assertEqual(self.indices_torch2python(indices_single),
-                         self.indices_torch2python([indices_batched[1]]))
+        for i in range(batch_size):
+            self.assertEqual(
+                self.indices_torch2python(indices_single),
+                self.indices_torch2python([indices_batched[i]]),
+            )

         # test with empty targets
         tgt_labels_empty = torch.randint(high=n_classes, size=(0,))
@@ -102,6 +114,57 @@ def test_model_detection_different_inputs(self):
         out = model([x])
         self.assertIn('pred_logits', out)

+    def test_box_iou_multiple_dimensions(self):
+        for extra_dims in range(3):
+            for extra_lengths in combinations_with_replacement(range(1, 4), extra_dims):
+                p = functools.reduce(operator.mul, extra_lengths, 1)
+                for n in range(3):
+                    a = torch.rand(extra_lengths + (n, 4))
+                    for m in range(3):
+                        b = torch.rand(extra_lengths + (m, 4))
+                        iou, union = box_ops.box_iou(a, b)
+                        self.assertTupleEqual(iou.shape, union.shape)
+                        self.assertTupleEqual(iou.shape, extra_lengths + (n, m))
+                        iou_it = iter(iou.view(p, n, m))
+                        for x, y in zip(a.view(p, n, 4), b.view(p, m, 4), strict=True):
+                            self.assertTrue(
+                                torch.equal(next(iou_it), ops.box_iou(x, y))
+                            )
+
+    def test_generalized_box_iou_multiple_dimensions(self):
+        a = torch.tensor([1, 1, 2, 2])
+        b = torch.tensor([1, 2, 3, 5])
+        ab = -0.1250
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(a[None, :], b[None, :]),
+                torch.Tensor([[ab]]),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(a[None, None, :], b[None, None, :]),
+                torch.Tensor([[[ab]]]),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(
+                    a[None, None, None, :], b[None, None, None, :]
+                ),
+                torch.Tensor([[[[ab]]]]),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                box_ops.generalized_box_iou(
+                    torch.stack([a, a, b, b]), torch.stack([a, b])
+                ),
+                torch.Tensor([[1, ab], [1, ab], [ab, 1], [ab, 1]]),
+            )
+        )
+
     def test_warpped_model_script_detection(self):
         class WrappedDETR(nn.Module):
             def __init__(self, model):
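
As an editorial aside (not part of the commit), the expected constant ab = -0.1250 in the new test can be verified by hand for a = [1, 1, 2, 2] and b = [1, 2, 3, 5] in xyxy format:

# Hand computation of generalized IoU for a = [1, 1, 2, 2], b = [1, 2, 3, 5] (xyxy).
area_a = (2 - 1) * (2 - 1)                                             # 1
area_b = (3 - 1) * (5 - 2)                                             # 6
inter = max(0, min(2, 3) - max(1, 1)) * max(0, min(2, 5) - max(1, 2))  # 1 * 0 = 0
union = area_a + area_b - inter                                        # 7
iou = inter / union                                                    # 0.0
hull = (max(2, 3) - min(1, 1)) * (max(2, 5) - min(1, 2))               # enclosing box: 2 * 4 = 8
giou = iou - (hull - union) / hull                                     # 0 - 1/8 = -0.125
assert giou == -0.1250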

util/box_ops.py

Lines changed: 14 additions & 14 deletions
@@ -20,18 +20,18 @@ def box_xyxy_to_cxcywh(x):
     return torch.stack(b, dim=-1)


-# modified from torchvision to also return the union
+# Modified from torchvision to also return the union and to work only on the last two dimensions, assuming the other ones are identical.
 def box_iou(boxes1, boxes2):
-    area1 = box_area(boxes1)
-    area2 = box_area(boxes2)
+    area1 = box_area(boxes1.view(-1, 4)).view(boxes1.shape[:-1])
+    area2 = box_area(boxes2.view(-1, 4)).view(boxes2.shape[:-1])

-    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
-    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+    lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2])  # [..., N,M,2]
+    rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:])  # [..., N,M,2]

-    wh = (rb - lt).clamp(min=0)  # [N,M,2]
-    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+    wh = (rb - lt).clamp(min=0)  # [..., N,M,2]
+    inter = wh[..., 0] * wh[..., 1]  # [..., N,M]

-    union = area1[:, None] + area2 - inter
+    union = area1[..., None] + area2[..., None, :] - inter

     iou = inter / union
     return iou, union
@@ -48,15 +48,15 @@ def generalized_box_iou(boxes1, boxes2):
     """
     # degenerate boxes gives inf / nan results
     # so do an early check
-    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
-    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    assert (boxes1[..., 2:] >= boxes1[..., :2]).all()
+    assert (boxes2[..., 2:] >= boxes2[..., :2]).all()
     iou, union = box_iou(boxes1, boxes2)

-    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
-    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+    lt = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
+    rb = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])

-    wh = (rb - lt).clamp(min=0)  # [N,M,2]
-    area = wh[:, :, 0] * wh[:, :, 1]
+    wh = (rb - lt).clamp(min=0)  # [..., N,M,2]
+    area = wh[..., 0] * wh[..., 1]

     return iou - (area - union) / area
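
A short usage sketch of the batched behaviour (editorial addition; the shapes are hypothetical and it assumes the patched util/box_ops.py above is on the import path):

import torch
from util.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou

# Hypothetical shapes: a batch of 2 images with 3 predicted and 5 target boxes each.
# Random cxcywh boxes always convert to valid xyxy boxes (x1 >= x0, y1 >= y0).
pred = box_cxcywh_to_xyxy(torch.rand(2, 3, 4))
tgt = box_cxcywh_to_xyxy(torch.rand(2, 5, 4))

iou, union = box_iou(pred, tgt)
giou = generalized_box_iou(pred, tgt)

# Pairwise over the last two dimensions, batched over the leading one.
assert iou.shape == union.shape == giou.shape == (2, 3, 5)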
