-
Notifications
You must be signed in to change notification settings - Fork 5
/
bbox.py
86 lines (72 loc) · 2.74 KB
/
bbox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Uses code from the TensorFlow Lite object detection examples:
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# See THIRD-PARTY NOTICES for full attribution and license notice.
# ===============================================================================
"""Classes representing object bounding boxes."""
import collections
# Immutable record with three fields, in order: id, score, bbox.
# Types are not enforced; bbox is presumably a BBox (defined below) — confirm
# against callers.
Object = collections.namedtuple('Object', 'id score bbox')
class BBox(collections.namedtuple('BBox', ['xmin', 'ymin', 'xmax', 'ymax'])):
    """
    Represents a rectangle whose sides are either vertical or horizontal,
    i.e. parallel to the x or y axis.

    The box is stored as its corner coordinates (xmin, ymin, xmax, ymax).
    Being a namedtuple it is immutable; every transforming method returns a
    new BBox.
    """
    # No per-instance __dict__: instances stay as small as plain tuples.
    __slots__ = ()

    @property
    def width(self):
        """Returns bounding box width (xmax - xmin); negative if invalid."""
        return self.xmax - self.xmin

    @property
    def height(self):
        """Returns bounding box height (ymax - ymin); negative if invalid."""
        return self.ymax - self.ymin

    @property
    def area(self):
        """Returns bounding box area (width * height).

        Only meaningful for valid boxes: an invalid box with both width and
        height negative reports a positive product.
        """
        return self.width * self.height

    @property
    def valid(self):
        """Returns whether bounding box is valid or not.

        A valid bounding box has xmin <= xmax and ymin <= ymax, which is
        equivalent to width >= 0 and height >= 0. Degenerate boxes with zero
        width or zero height count as valid.
        """
        return self.width >= 0 and self.height >= 0

    def scale(self, sx, sy):
        """Returns bounding box with x coordinates multiplied by sx and y
        coordinates multiplied by sy."""
        return BBox(xmin=sx * self.xmin,
                    ymin=sy * self.ymin,
                    xmax=sx * self.xmax,
                    ymax=sy * self.ymax)

    def translate(self, dx, dy):
        """Returns bounding box shifted by dx along x and dy along y."""
        return BBox(xmin=dx + self.xmin,
                    ymin=dy + self.ymin,
                    xmax=dx + self.xmax,
                    ymax=dy + self.ymax)

    def map(self, f):
        """Returns bounding box modified by applying f to each coordinate."""
        return BBox(xmin=f(self.xmin),
                    ymin=f(self.ymin),
                    xmax=f(self.xmax),
                    ymax=f(self.ymax))

    @staticmethod
    def intersect(a, b):
        """Returns the intersection of two bounding boxes.

        The result is invalid (see `valid`) when the boxes do not overlap.
        Boxes that merely touch along an edge give a valid zero-area box.
        """
        return BBox(xmin=max(a.xmin, b.xmin),
                    ymin=max(a.ymin, b.ymin),
                    xmax=min(a.xmax, b.xmax),
                    ymax=min(a.ymax, b.ymax))

    @staticmethod
    def union(a, b):
        """Returns the smallest bounding box enclosing both a and b.

        The result is valid whenever at least one of a and b is valid.
        """
        return BBox(xmin=min(a.xmin, b.xmin),
                    ymin=min(a.ymin, b.ymin),
                    xmax=max(a.xmax, b.xmax),
                    ymax=max(a.ymax, b.ymax))

    @staticmethod
    def iou(a, b):
        """Returns intersection-over-union value of two bounding boxes.

        Returns 0.0 when the boxes do not overlap. Also returns 0.0 when the
        union area is zero, e.g. for two degenerate zero-area boxes, instead
        of raising ZeroDivisionError.
        """
        intersection = BBox.intersect(a, b)
        if not intersection.valid:
            return 0.0
        area = intersection.area
        union_area = a.area + b.area - area
        # A valid intersection implies both inputs are valid, so union_area
        # is >= 0; it is exactly 0 only for degenerate (zero-area) inputs.
        if union_area <= 0:
            return 0.0
        return area / union_area