Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions unstructured_inference/models/table_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,11 +517,26 @@ def nms_supercells(supercells):
num_supercells = len(supercells)
suppression = [False for supercell in supercells]


# Precompute sets for quick overlap checks and update them when a supercell is modified.
row_sets = [set(sc["row_numbers"]) for sc in supercells]
col_sets = [set(sc["column_numbers"]) for sc in supercells]

for supercell2_num in range(1, num_supercells):
supercell2 = supercells[supercell2_num]
row_set2 = row_sets[supercell2_num]
col_set2 = col_sets[supercell2_num]
for supercell1_num in range(supercell2_num):
# Quick skip if there's no possible overlap on rows or columns
if not (row_set2 & row_sets[supercell1_num]) or not (col_set2 & col_sets[supercell1_num]):
continue
supercell1 = supercells[supercell1_num]
remove_supercell_overlap(supercell1, supercell2)
# Update sets for supercell2 because it may have been mutated
row_set2 = set(supercell2["row_numbers"])
col_set2 = set(supercell2["column_numbers"])
row_sets[supercell2_num] = row_set2
col_sets[supercell2_num] = col_set2
if (
(len(supercell2["row_numbers"]) < 2 and len(supercell2["column_numbers"]) < 2)
or len(supercell2["row_numbers"]) == 0
Expand Down Expand Up @@ -574,10 +589,17 @@ def remove_supercell_overlap(supercell1, supercell2):
supercell #1. This resolves the overlap by removing fewer grid cells from
supercell #1 than if we eliminated column C from it.
"""
common_rows = set(supercell1["row_numbers"]).intersection(set(supercell2["row_numbers"]))
common_columns = set(supercell1["column_numbers"]).intersection(
set(supercell2["column_numbers"]),
)
# Local references to avoid repeated dict lookups
rows1 = supercell1["row_numbers"]
rows2 = supercell2["row_numbers"]
cols1 = supercell1["column_numbers"]
cols2 = supercell2["column_numbers"]

common_rows = set(rows1) & set(rows2)
common_columns = set(cols1) & set(cols2)

# While the supercells have overlapping grid cells, continue shrinking the less-confident
# supercell one row or one column at a time

# While the supercells have overlapping grid cells, continue shrinking the less-confident
# supercell one row or one column at a time
Expand All @@ -586,27 +608,32 @@ def remove_supercell_overlap(supercell1, supercell2):
# if the supercell has fewer rows than columns, remove an overlapping column,
# because this removes fewer grid cells from the supercell;
# otherwise remove an overlapping row
if len(supercell2["row_numbers"]) < len(supercell2["column_numbers"]):
min_column = min(supercell2["column_numbers"])
max_column = max(supercell2["column_numbers"])
if len(rows2) < len(cols2):
# compute extremes only when needed
min_column = min(cols2) if cols2 else None
max_column = max(cols2) if cols2 else None
if max_column in common_columns:
common_columns.remove(max_column)
supercell2["column_numbers"].remove(max_column)
# remove from the underlying list (keeps original behavior)
cols2.remove(max_column)
elif min_column in common_columns:
common_columns.remove(min_column)
supercell2["column_numbers"].remove(min_column)
cols2.remove(min_column)
else:
supercell2["column_numbers"] = []
common_columns = set()
# sync local reference
cols2 = supercell2["column_numbers"]
else:
min_row = min(supercell2["row_numbers"])
max_row = max(supercell2["row_numbers"])
min_row = min(rows2) if rows2 else None
max_row = max(rows2) if rows2 else None
if max_row in common_rows:
common_rows.remove(max_row)
supercell2["row_numbers"].remove(max_row)
rows2.remove(max_row)
elif min_row in common_rows:
common_rows.remove(min_row)
supercell2["row_numbers"].remove(min_row)
rows2.remove(min_row)
else:
supercell2["row_numbers"] = []
common_rows = set()
rows2 = supercell2["row_numbers"]