Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions unstructured_inference/models/table_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,16 +364,41 @@ def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_hi
objects = sort_objects_by_score(objects, reverse=keep_higher)

num_objects = len(objects)
suppression = [False for obj in objects]
suppression = [False] * num_objects

# Precompute bboxes and areas to avoid constructing Rect objects in the inner loop.
bboxes = [obj["bbox"] for obj in objects]
areas = []
for bbox in bboxes:
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
a = w * h
areas.append(a if a > 0 else 0.0)


for object2_num in range(1, num_objects):
object2_rect = Rect(objects[object2_num]["bbox"])
object2_area = object2_rect.get_area()
x2_min, y2_min, x2_max, y2_max = bboxes[object2_num]
object2_area = areas[object2_num]
for object1_num in range(object2_num):
if not suppression[object1_num]:
object1_rect = Rect(objects[object1_num]["bbox"])
object1_area = object1_rect.get_area()
intersect_area = object1_rect.intersect(object2_rect).get_area()
x1_min, y1_min, x1_max, y1_max = bboxes[object1_num]
object1_area = areas[object1_num]

# Replicate Rect.intersect behavior:
# If object1_area == 0, then intersect returns object2 (so intersection area = object2_area)
if object1_area == 0.0:
intersect_area = object2_area
else:
inter_w = min(x1_max, x2_max) - max(x1_min, x2_min)
if inter_w <= 0:
intersect_area = 0.0
else:
inter_h = min(y1_max, y2_max) - max(y1_min, y2_min)
if inter_h <= 0:
intersect_area = 0.0
else:
intersect_area = inter_w * inter_h

try:
if match_criteria == "object1_overlap":
metric = intersect_area / object1_area
Expand Down