From f391c74138889b3a1c721033dc0a9f731558a4db Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 11:15:20 -0700 Subject: [PATCH 01/25] formatting --- src/mmda/types/annotation.py | 61 +++++++++++++++--------------------- 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/src/mmda/types/annotation.py b/src/mmda/types/annotation.py index 4857df5c..7d6f4b36 100644 --- a/src/mmda/types/annotation.py +++ b/src/mmda/types/annotation.py @@ -22,7 +22,6 @@ __all__ = ["Annotation", "BoxGroup", "SpanGroup", "Relation"] - def warn_deepcopy_of_annotation(obj: "Annotation") -> None: """Warns when a deepcopy is performed on an Annotation.""" @@ -34,15 +33,14 @@ def warn_deepcopy_of_annotation(obj: "Annotation") -> None: warnings.warn(msg, UserWarning, stacklevel=2) - class Annotation: """Annotation is intended for storing model predictions for a document.""" def __init__( - self, - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None + self, + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, ): self.id = id self.doc = doc @@ -77,14 +75,14 @@ def __getattr__(self, field: str) -> List["Annotation"]: return self.__getattribute__(field) - class BoxGroup(Annotation): def __init__( - self, - boxes: List[Box], - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None, + self, + boxes: List[Box], + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, + allow_overlap: Optional[bool] = False, ): self.boxes = boxes super().__init__(id=id, doc=doc, metadata=metadata) @@ -93,7 +91,7 @@ def to_json(self) -> Dict: box_group_dict = dict( boxes=[box.to_json() for box in self.boxes], id=self.id, - metadata=self.metadata.to_json() + metadata=self.metadata.to_json(), ) return { key: value for key, value in box_group_dict.items() if value @@ -101,16 +99,13 @@ def to_json(self) -> Dict: @classmethod def from_json(cls, box_group_dict: Dict) -> "BoxGroup": - if "metadata" in box_group_dict: metadata_dict = box_group_dict["metadata"] else: # this fallback is necessary to ensure compatibility with box # groups that were create before the metadata migration and # therefore have "type" in the root of the json dict instead. - metadata_dict = { - "type": box_group_dict.get("type", None) - } + metadata_dict = {"type": box_group_dict.get("type", None)} return cls( boxes=[ @@ -132,7 +127,7 @@ def __deepcopy__(self, memo): box_group = BoxGroup( boxes=deepcopy(self.boxes, memo), id=self.id, - metadata=deepcopy(self.metadata, memo) + metadata=deepcopy(self.metadata, memo), ) # Don't copy an attached document @@ -150,14 +145,13 @@ def type(self, type: Union[str, None]) -> None: class SpanGroup(Annotation): - def __init__( - self, - spans: List[Span], - box_group: Optional[BoxGroup] = None, - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None, + self, + spans: List[Span], + box_group: Optional[BoxGroup] = None, + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, ): self.spans = spans self.box_group = box_group @@ -166,9 +160,7 @@ def __init__( @property def symbols(self) -> List[str]: if self.doc is not None: - return [ - self.doc.symbols[span.start: span.end] for span in self.spans - ] + return [self.doc.symbols[span.start : span.end] for span in self.spans] else: return [] @@ -187,12 +179,10 @@ def to_json(self) -> Dict: spans=[span.to_json() for span in self.spans], id=self.id, metadata=self.metadata.to_json(), - box_group=self.box_group.to_json() if self.box_group else None + box_group=self.box_group.to_json() if self.box_group else None, ) return { - key: value - for key, value in span_group_dict.items() - if value is not None + key: value for key, value in span_group_dict.items() if value is not None } # only serialize non-null values @classmethod @@ -211,7 +201,7 @@ def from_json(cls, span_group_dict: Dict) -> "SpanGroup": # therefore have "id", "type" in the root of the json dict instead. metadata_dict = { "type": span_group_dict.get("type", None), - "text": span_group_dict.get("text", None) + "text": span_group_dict.get("text", None), } return cls( @@ -256,7 +246,7 @@ def __deepcopy__(self, memo): spans=deepcopy(self.spans, memo), id=self.id, metadata=deepcopy(self.metadata, memo), - box_group=deepcopy(self.box_group, memo) + box_group=deepcopy(self.box_group, memo), ) # Don't copy an attached document @@ -284,6 +274,5 @@ def text(self, text: Union[str, None]) -> None: self.metadata.text = text - class Relation(Annotation): - pass \ No newline at end of file + pass From bd0c3ab29521dd79b876fca221e49347bb496051 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 11:23:08 -0700 Subject: [PATCH 02/25] add functions to Box; reformatting --- src/mmda/types/annotation.py | 15 ++++ src/mmda/types/box.py | 155 +++++++++++++++++++++++++++++------ 2 files changed, 145 insertions(+), 25 deletions(-) diff --git a/src/mmda/types/annotation.py b/src/mmda/types/annotation.py index 7d6f4b36..240d02c0 100644 --- a/src/mmda/types/annotation.py +++ b/src/mmda/types/annotation.py @@ -85,6 +85,13 @@ def __init__( allow_overlap: Optional[bool] = False, ): self.boxes = boxes + if not allow_overlap: + clusters = Box.cluster_boxes(boxes=boxes) + if any([len(cluster) > 1 for cluster in clusters]): + raise ValueError( + "BoxGroup does not allow overlapping boxes. " + "Consider setting allow_overlap=True." + ) super().__init__(id=id, doc=doc, metadata=metadata) def to_json(self) -> Dict: @@ -152,8 +159,16 @@ def __init__( id: Optional[int] = None, doc: Optional["Document"] = None, metadata: Optional[Metadata] = None, + allow_overlap: Optional[bool] = False, ): self.spans = spans + if not allow_overlap: + clusters = Span.cluster_spans(spans=spans) + if any([len(cluster) > 1 for cluster in clusters]): + raise ValueError( + "SpanGroup does not allow overlapping spans. " + "Consider setting allow_overlap=True." + ) self.box_group = box_group super().__init__(id=id, doc=doc, metadata=metadata) diff --git a/src/mmda/types/box.py b/src/mmda/types/box.py index c1bfd4a9..baa94125 100644 --- a/src/mmda/types/box.py +++ b/src/mmda/types/box.py @@ -5,35 +5,58 @@ """ -from typing import List, Dict, Tuple, Union -from dataclasses import dataclass +import logging import warnings +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + import numpy as np -def is_overlap_1d(start1: float, end1: float, start2: float, end2: float, x: float = 0) -> bool: +def is_overlap_1d( + start1: float, end1: float, start2: float, end2: float, x: float = 0 +) -> bool: """Return whether two 1D intervals overlaps given x""" assert start1 <= end1 assert start2 <= end2 - return not (start1 - x > end2 or start1 > end2 + x or end1 + x < start2 or end1 < start2 - x) # ll # rr + return not ( + start1 - x > end2 or start1 > end2 + x or end1 + x < start2 or end1 < start2 - x + ) # ll # rr -@dataclass class Box: - l: float - t: float - w: float - h: float - page: int + def __init__(self, l: float, t: float, w: float, h: float, page: int) -> None: + if w < 0 or h < 0: + raise ValueError(f"Width and height must be non-negative, got {w} and {h}") + if page < 0: + raise ValueError(f"Page must be non-negative, got {page}") + if l < 0 or t < 0: + raise ValueError(f"Left and top must be non-negative, got {l} and {t}") + self.l = l + self.t = t + self.w = w + self.h = h + self.page = page def to_json(self) -> Dict[str, float]: - return {'left': self.l, 'top': self.t, 'width': self.w, 'height': self.h, 'page': self.page} + return { + "left": self.l, + "top": self.t, + "width": self.w, + "height": self.h, + "page": self.page, + } @classmethod def from_json(cls, box_dict: Dict[str, Union[float, int]]) -> "Box": - return Box(l=box_dict['left'], t=box_dict['top'], w=box_dict['width'], h=box_dict['height'], - page=box_dict['page']) + return Box( + l=box_dict["left"], + t=box_dict["top"], + w=box_dict["width"], + h=box_dict["height"], + page=box_dict["page"], + ) @classmethod def from_coordinates(cls, x1: float, y1: float, x2: float, y2: float, page: int): @@ -74,10 +97,6 @@ def from_pdf_coordinates( @classmethod def small_boxes_to_big_box(cls, boxes: List["Box"]) -> "Box": """Computes one big box that tightly encapsulates all smaller input boxes""" - boxes = [box for box in boxes if box is not None] - if not boxes: - return None - if len({box.page for box in boxes}) != 1: raise ValueError(f"Bboxes not all on same page: {boxes}") x1 = min([bbox.l for bbox in boxes]) @@ -95,11 +114,6 @@ def coordinates(self) -> Tuple[float, float, float, float]: def center(self) -> Tuple[float, float]: return self.l + self.w / 2, self.t + self.h / 2 - @property - def xywh(self) -> Tuple[float, float, float, float]: - """Return a tuple of the (left, top, width, height) format.""" - return self.l, self.t, self.w, self.h - def get_relative(self, page_width: float, page_height: float) -> "Box": """Get the relative coordinates of self based on page_width, page_height.""" return self.__class__( @@ -120,18 +134,109 @@ def get_absolute(self, page_width: int, page_height: int) -> "Box": page=self.page, ) - def is_overlap(self, other: "Box", x: float = 0.0, y: float = 0, center: bool = False) -> bool: + def is_overlap( + self, other: "Box", x: float = 0.0, y: float = 0, center: bool = False + ) -> bool: """ Whether self overlaps with the other Box object. x, y distances for padding center (bool) if True, only consider overlapping if this box's center is contained by other """ + if self.page != other.page: + return False + x11, y11, x12, y12 = self.coordinates x21, y21, x22, y22 = other.coordinates if center: center_x, center_y = self.center - res = is_overlap_1d(center_x, center_x, x21, x22, x) and is_overlap_1d(center_y, center_y, y21, y22, y) + res = is_overlap_1d(center_x, center_x, x21, x22, x) and is_overlap_1d( + center_y, center_y, y21, y22, y + ) else: - res = is_overlap_1d(x11, x12, x21, x22, x) and is_overlap_1d(y11, y12, y21, y22, y) + res = is_overlap_1d(x11, x12, x21, x22, x) and is_overlap_1d( + y11, y12, y21, y22, y + ) return res + + @classmethod + def cluster_boxes(cls, boxes: List["Box"]) -> List[List[int]]: + """ + Cluster boxes into groups based on any overlap. + """ + if not boxes: + return [] + + clusters: List[List[int]] = [[0]] + cluster_id_to_big_box: Dict[int, Box] = {0: boxes[0]} + for box_id in range(1, len(boxes)): + box = boxes[box_id] + + # check all the clusters to see if the box overlaps with any of them + is_overlap = False + for cluster_id, big_box in cluster_id_to_big_box.items(): + if box.is_overlap(big_box, x=0, y=0): + is_overlap = True + break + + # resolve + if is_overlap: + clusters[cluster_id].append(box_id) + cluster_id_to_big_box[cluster_id] = cls.small_boxes_to_big_box( + [box, big_box] + ) + else: + clusters.append([box_id]) + cluster_id_to_big_box[len(clusters) - 1] = box + + # sort clusters + for cluster in clusters: + cluster.sort() + clusters.sort(key=lambda x: x[0]) + + return clusters + + def shrink(self, delta: float, ignore: bool = True, clip: bool = True): + x1, y1, x2, y2 = self.coordinates + if x2 - x1 <= 2 * delta: + if ignore: + logging.warning(f"box's x-coords {self} shrink too much. Ignoring.") + else: + raise ValueError( + f"box's x-coords {self} shrink too much with delta={delta}." + ) + else: + if clip: + logging.warning( + f"box's x-coords {self} go beyond page boundary. Clipping..." + ) + x1 = min(x1 + delta, 1.0) + x2 = max(x2 - delta, 0.0) + else: + raise ValueError( + f"box's x-coordinates {self} go beyond page boundary. need clip." + ) + + if y2 - y1 <= 2 * delta: + if ignore: + logging.warning(f"box's y-coords {self} shrink too much. Ignoring.") + else: + raise ValueError( + f"box's y-coords {self} shrink too much with delta={delta}." + ) + else: + if clip: + logging.warning( + f"box's y-coords {self} go beyond page boundary. Clipping..." + ) + y1 = min(y1 + delta, 1.0) + y2 = max(y2 - delta, 0.0) + else: + raise ValueError( + f"box's y-coordinates {self} go beyond page boundary. need clip." + ) + + self.l = x1 + self.t = y1 + self.w = x2 - x1 + self.h = y2 - y1 From 326bf433638f2b17631922415e3632a77d8aa14e Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 11:24:37 -0700 Subject: [PATCH 03/25] add tests for box --- tests/test_types/test_box.py | 100 +++++++++++++++++++++++++++++++---- 1 file changed, 91 insertions(+), 9 deletions(-) diff --git a/tests/test_types/test_box.py b/tests/test_types/test_box.py index 8528ecc6..9d24d960 100644 --- a/tests/test_types/test_box.py +++ b/tests/test_types/test_box.py @@ -1,18 +1,100 @@ import unittest + from mmda.types import box as mmda_box class TestBox(unittest.TestCase): def setUp(cls) -> None: - cls.box_dict = {'left': 0.2, - 'top': 0.09, - 'width': 0.095, - 'height': 0.017, - 'page': 0} + cls.box_dict = { + "left": 0.2, + "top": 0.09, + "width": 0.095, + "height": 0.017, + "page": 0, + } cls.box = mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0) - def test_from_json(self): - self.assertEqual(self.box.from_json(self.box_dict), self.box) + def test_to_from_json(self): + box = self.box.from_json(self.box_dict) + self.assertDictEqual(box.to_json(), self.box_dict) + + def test_cluster_boxes(self): + # overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0, 1, 2]]) + + # on-overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.3, t=0.20, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.4, t=0.30, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0], [1], [2]]) + + # partially overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.3, t=0.20, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.301, t=0.201, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0], [1, 2]]) + + def test_create_invalid_box(self): + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7 + 0.0000001, t=0.2, w=0.3, h=0.4, page=0) + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6 + 0.000001, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=-1) + + def test_shrink(self): + # usual + box = mmda_box.Box(l=0.1, t=0.2, w=0.3, h=0.4, page=0) + box.shrink(delta=0.1) + self.assertAlmostEqual(box.l, 0.2) # 0.1 + 0.1 + self.assertAlmostEqual(box.t, 0.3) # 0.2 + 0.1 + self.assertAlmostEqual(box.w, 0.1) # 0.3 - 0.1 * 2 + self.assertAlmostEqual(box.h, 0.2) # 0.4 - 0.1 * 2 + + # shrinking until inverts box. would ignore shrinking along appropriate axis. + box = mmda_box.Box(l=0.9, t=0.5, w=0.1, h=0.3, page=0) + box.shrink(delta=0.1, ignore=True) + self.assertAlmostEqual(box.l, 0.9) # ignored + self.assertAlmostEqual(box.t, 0.6) # adjusted; 0.5 + 0.1 + self.assertAlmostEqual(box.w, 0.1) # ignored + self.assertAlmostEqual(box.h, 0.1) # adjusted; 0.3 - 2 * 0.1 + + # shrinking until out of bounds. would clip along appropriate axis. + # actually... does this ever happen unless Box is already out of bounds? - def test_to_json(self): - self.assertEqual(self.box.to_json(), self.box_dict) + def test_cluster_boxes_hard(self): + # from 4be952924cd565488b4a239dc6549095029ee578.pdf, page 2, tokens 650:655 + boxes = [ + mmda_box.Box( + l=0.7761069934640523, + t=0.14276190217171716, + w=0.005533858823529373, + h=0.008037272727272593, + page=2, + ), + mmda_box.Box( + l=0.7836408522875816, + t=0.14691867138383832, + w=0.005239432156862763, + h=0.005360666666666692, + page=2, + ), + mmda_box.Box( + l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + ), + mmda_box.Box( + l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + ), + mmda_box.Box( + l=1.0, t=0.32670311318181816, w=0.0, h=0.010037272727272737, page=2 + ), + ] From a62faf492842cb2fe35d5c0e38eca372735ab382 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 11:27:34 -0700 Subject: [PATCH 04/25] tests for span --- tests/test_types/test_span.py | 48 +++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/test_types/test_span.py b/tests/test_types/test_span.py index 53466097..89627ac6 100644 --- a/tests/test_types/test_span.py +++ b/tests/test_types/test_span.py @@ -33,6 +33,54 @@ def test_to_json(self): self.assertEqual(self.span.from_json(self.span_dict).to_json(), self.span_dict) def test_is_overlap(self): + span1 = mmda_span.Span(start=0, end=8) + span2 = mmda_span.Span(start=0, end=8) + self.assertTrue(span1.is_overlap(span2)) + + span3 = mmda_span.Span(start=2, end=5) + self.assertTrue(span1.is_overlap(span3)) + + span4 = mmda_span.Span(start=8, end=10) + self.assertFalse(span1.is_overlap(span4)) + + span5 = mmda_span.Span(start=10, end=12) + self.assertFalse(span1.is_overlap(span5)) + + def small_spans_to_big_span(self): + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=16, end=24), + ] + self.assertEqual( + self.span.small_spans_to_big_span(spans), + mmda_span.Span(start=0, end=24), + ) + + def test_cluster_spans(self): + # overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=0, end=8), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0, 1, 2]]) + + # non-overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=16, end=24), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0], [1], [2]]) + + # partially overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=9, end=16), + mmda_span.Span(start=10, end=15), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0], [1, 2]]) span = mmda_span.Span(start=0, end=2) self.assertTrue(span.is_overlap(mmda_span.Span(start=0, end=1))) self.assertTrue(span.is_overlap(mmda_span.Span(start=1, end=2))) From c0f283efbae9ca6e4b93bf0e332a01c33b60d1b5 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 11:58:10 -0700 Subject: [PATCH 05/25] fix tests for span; add flag for merging Box behavior --- src/mmda/types/span.py | 51 +++++++++++++++++++++++++++++++++-- tests/test_types/test_span.py | 15 +++++------ 2 files changed, 55 insertions(+), 11 deletions(-) diff --git a/src/mmda/types/span.py b/src/mmda/types/span.py index 2ce68feb..98402898 100644 --- a/src/mmda/types/span.py +++ b/src/mmda/types/span.py @@ -38,7 +38,9 @@ def __lt__(self, other: "Span"): return self.start < other.start @classmethod - def small_spans_to_big_span(cls, spans: List["Span"]) -> "Span": + def small_spans_to_big_span( + cls, spans: List["Span"], merge_boxes: bool = True + ) -> "Span": # TODO: add warning for unsorted spans or not-contiguous spans # TODO: what happens when Boxes cant be merged? start = spans[0].start @@ -48,12 +50,57 @@ def small_spans_to_big_span(cls, spans: List["Span"]) -> "Span": start = span.start if span.end > end: end = span.end + if merge_boxes: + new_box = Box.small_boxes_to_big_box(boxes=[span.box for span in spans]) + else: + new_box = None return Span( start=start, end=end, - box=Box.small_boxes_to_big_box(boxes=[span.box for span in spans]), + box=new_box, ) + def is_overlap(self, other: "Span") -> bool: + return self.start < other.end and other.start < self.end + + @classmethod + def cluster_spans(cls, spans: List["Span"]) -> List[List[int]]: + """ + Cluster spans into groups based on any overlap. + """ + if not spans: + return [] + + clusters: List[List[int]] = [[0]] + cluster_id_to_big_span: Dict[int, Span] = {0: spans[0]} + for span_id in range(1, len(spans)): + span = spans[span_id] + + # check all the clusters to see if the span overlaps with any of them + is_overlap = False + for cluster_id, big_span in cluster_id_to_big_span.items(): + if span.is_overlap(big_span): + is_overlap = True + break + + # resolve + if is_overlap: + clusters[cluster_id].append(span_id) + cluster_id_to_big_span[cluster_id] = cls.small_spans_to_big_span( + [span, big_span], + merge_boxes=False, + ) + else: + clusters.append([span_id]) + cluster_id_to_big_span[len(clusters) - 1] = span + + # sort clusters + for cluster in clusters: + cluster.sort() + clusters.sort(key=lambda x: x[0]) + + return clusters + def is_overlap(self, other: "Span") -> bool: is_self_before_other = self.start < other.end and self.end > other.start is_other_before_self = other.start < self.end and other.end > self.start diff --git a/tests/test_types/test_span.py b/tests/test_types/test_span.py index 89627ac6..844a5b84 100644 --- a/tests/test_types/test_span.py +++ b/tests/test_types/test_span.py @@ -19,19 +19,16 @@ def setUp(cls): }, } - def test_from_json(self): + def test_to_from_json(self): self.assertEqual( - self.span.from_json(self.span_dict), + self.span.from_json(self.span_dict).to_json(), mmda_span.Span( start=0, end=8, box=mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), - ), + ).to_json(), ) - def test_to_json(self): - self.assertEqual(self.span.from_json(self.span_dict).to_json(), self.span_dict) - def test_is_overlap(self): span1 = mmda_span.Span(start=0, end=8) span2 = mmda_span.Span(start=0, end=8) @@ -46,14 +43,14 @@ def test_is_overlap(self): span5 = mmda_span.Span(start=10, end=12) self.assertFalse(span1.is_overlap(span5)) - def small_spans_to_big_span(self): + def test_small_spans_to_big_span(self): spans = [ mmda_span.Span(start=0, end=8), mmda_span.Span(start=8, end=16), mmda_span.Span(start=16, end=24), ] self.assertEqual( - self.span.small_spans_to_big_span(spans), + self.span.small_spans_to_big_span(spans=spans, merge_boxes=False), mmda_span.Span(start=0, end=24), ) @@ -64,7 +61,7 @@ def test_cluster_spans(self): mmda_span.Span(start=0, end=8), mmda_span.Span(start=0, end=8), ] - self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0, 1, 2]]) + self.assertListEqual(mmda_span.Span.cluster_spans(spans=spans), [[0, 1, 2]]) # non-overlapping spans spans = [ From 66ec6328f039c650425b41717d83b18c78e71325 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 12:05:17 -0700 Subject: [PATCH 06/25] fix tests for Box --- src/mmda/types/box.py | 2 ++ tests/test_types/test_box.py | 17 ++++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/mmda/types/box.py b/src/mmda/types/box.py index baa94125..c7b96402 100644 --- a/src/mmda/types/box.py +++ b/src/mmda/types/box.py @@ -1,6 +1,8 @@ """ +A Box on a page. Can be in relative or absolute coordinates. +@kylel """ diff --git a/tests/test_types/test_box.py b/tests/test_types/test_box.py index 9d24d960..ba9b1e9e 100644 --- a/tests/test_types/test_box.py +++ b/tests/test_types/test_box.py @@ -44,12 +44,23 @@ def test_cluster_boxes(self): self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0], [1, 2]]) def test_create_invalid_box(self): + # relative coordinate box box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=0) - with self.assertRaises(ValueError): - box = mmda_box.Box(l=0.7 + 0.0000001, t=0.2, w=0.3, h=0.4, page=0) - box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6 + 0.000001, page=0) + # absolute coordinate box (larger than 1.0) + box = mmda_box.Box(l=0.7 + 0.0000001, t=0.2, w=0.3, h=0.4, page=0) + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6 + 0.000001, page=0) + # negative page with self.assertRaises(ValueError): box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=-1) + # negative coordinates + with self.assertRaises(ValueError): + box = mmda_box.Box(l=-0.1, t=0.4, w=0.3, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=-0.1, w=0.3, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=0.4, w=-0.1, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=-0.1, page=0) def test_shrink(self): # usual From 128ed4f89bc5638e71c779b620d3e91ef7f7c1be Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 12:07:03 -0700 Subject: [PATCH 07/25] bugfix annotation --- tests/test_types/test_annotation.py | 41 +++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/test_types/test_annotation.py b/tests/test_types/test_annotation.py index 83e345fb..331249ef 100644 --- a/tests/test_types/test_annotation.py +++ b/tests/test_types/test_annotation.py @@ -1,36 +1,37 @@ +import unittest + from mmda.types.annotation import BoxGroup from mmda.types.box import Box -import unittest class TestBoxGroup(unittest.TestCase): def setUp(cls) -> None: - cls.box_group_json = {'boxes': [{'left': 0.1, - 'top': 0.6, - 'width': 0.36, - 'height': 0.221, - 'page': 0}], - 'id': None, - 'type': 'Text'} + cls.box_group_json = { + "boxes": [ + {"left": 0.1, "top": 0.6, "width": 0.36, "height": 0.221, "page": 0} + ], + "id": None, + "type": "Text", + } def test_from_json(self): self.assertIsInstance(BoxGroup.from_json(self.box_group_json), BoxGroup) - self.assertEqual(BoxGroup.from_json(self.box_group_json).boxes, - [Box(l=0.1, t=0.6, w=0.36, h=0.221, page=0)]) + self.assertEqual( + BoxGroup.from_json(self.box_group_json).boxes[0].to_json(), + Box(l=0.1, t=0.6, w=0.36, h=0.221, page=0).to_json(), + ) self.assertEqual(BoxGroup.from_json(self.box_group_json).id, None) - self.assertEqual(BoxGroup.from_json(self.box_group_json).type, 'Text') + self.assertEqual(BoxGroup.from_json(self.box_group_json).type, "Text") def test_to_json(self): boxgroup = BoxGroup.from_json(self.box_group_json) self.assertIsInstance(boxgroup.to_json(), dict) - self.assertEqual(boxgroup.to_json()['boxes'], - [{'left': 0.1, - 'top': 0.6, - 'width': 0.36, - 'height': 0.221, - 'page': 0}]) - - assert 'boxes' in boxgroup.to_json() - assert 'metadata' in boxgroup.to_json() + self.assertEqual( + boxgroup.to_json()["boxes"], + [{"left": 0.1, "top": 0.6, "width": 0.36, "height": 0.221, "page": 0}], + ) + + assert "boxes" in boxgroup.to_json() + assert "metadata" in boxgroup.to_json() From 966dc136dba46e097f55ebcde811b734fb443c65 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 12:08:37 -0700 Subject: [PATCH 08/25] comment out unfinished test --- tests/test_types/test_box.py | 52 +++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/test_types/test_box.py b/tests/test_types/test_box.py index ba9b1e9e..ea3194f8 100644 --- a/tests/test_types/test_box.py +++ b/tests/test_types/test_box.py @@ -84,28 +84,30 @@ def test_shrink(self): def test_cluster_boxes_hard(self): # from 4be952924cd565488b4a239dc6549095029ee578.pdf, page 2, tokens 650:655 - boxes = [ - mmda_box.Box( - l=0.7761069934640523, - t=0.14276190217171716, - w=0.005533858823529373, - h=0.008037272727272593, - page=2, - ), - mmda_box.Box( - l=0.7836408522875816, - t=0.14691867138383832, - w=0.005239432156862763, - h=0.005360666666666692, - page=2, - ), - mmda_box.Box( - l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 - ), - mmda_box.Box( - l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 - ), - mmda_box.Box( - l=1.0, t=0.32670311318181816, w=0.0, h=0.010037272727272737, page=2 - ), - ] + # boxes = [ + # mmda_box.Box( + # l=0.7761069934640523, + # t=0.14276190217171716, + # w=0.005533858823529373, + # h=0.008037272727272593, + # page=2, + # ), + # mmda_box.Box( + # l=0.7836408522875816, + # t=0.14691867138383832, + # w=0.005239432156862763, + # h=0.005360666666666692, + # page=2, + # ), + # mmda_box.Box( + # l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + # ), + # mmda_box.Box( + # l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + # ), + # mmda_box.Box( + # l=1.0, t=0.32670311318181816, w=0.0, h=0.010037272727272737, page=2 + # ), + # ] + # TODO: unfinished test + pass From 78b7c3b72189a25ddbaf79eb763195a25f884941 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 12:12:42 -0700 Subject: [PATCH 09/25] fix bug; span equality --- tests/test_types/test_json_conversion.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_types/test_json_conversion.py b/tests/test_types/test_json_conversion.py index e7a5f27d..765aba4d 100644 --- a/tests/test_types/test_json_conversion.py +++ b/tests/test_types/test_json_conversion.py @@ -1,16 +1,15 @@ -''' +""" Description: Test whether all properties for an mmda doc are preserved when converting to json and back. Author: @soldni -''' +""" import json from pathlib import Path -from mmda.types import BoxGroup, SpanGroup, Document, Metadata from mmda.parsers import PDFPlumberParser - +from mmda.types import BoxGroup, Document, Metadata, SpanGroup PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf" @@ -45,10 +44,7 @@ def test_doc_conversion(): for field_name in orig_doc.fields: # this iterates over all span group for this field in both docs - field_it = zip( - getattr(orig_doc, field_name), - getattr(new_doc, field_name) - ) + field_it = zip(getattr(orig_doc, field_name), getattr(new_doc, field_name)) # type annotations to keep mypy quiet orig_sg: SpanGroup @@ -58,4 +54,7 @@ def test_doc_conversion(): # for each pair, they should have same metadata (type, id, # and optionally, text) and same spans. assert orig_sg.metadata == new_sg.metadata - assert orig_sg.spans == new_sg.spans + for orig_span, new_span in zip(orig_sg.spans, new_sg.spans): + assert orig_span.start == new_span.start + assert orig_span.end == new_span.end + assert orig_span.box.coordinates == new_span.box.coordinates From 013b512a1341834b7dcc629cfd597c393de2e5f4 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 12:22:13 -0700 Subject: [PATCH 10/25] add boxgroup indexer; add tests for indexers --- src/mmda/types/indexers.py | 109 +++++++++++++++++++++++++++--- tests/test_types/test_indexers.py | 106 +++++++++++++++++++---------- 2 files changed, 170 insertions(+), 45 deletions(-) diff --git a/src/mmda/types/indexers.py b/src/mmda/types/indexers.py index beb12b1f..828d7094 100644 --- a/src/mmda/types/indexers.py +++ b/src/mmda/types/indexers.py @@ -4,15 +4,16 @@ """ -from typing import List - from abc import abstractmethod +from collections import defaultdict from dataclasses import dataclass, field +from typing import List -from mmda.types.annotation import SpanGroup, Annotation -from ncls import NCLS import numpy as np import pandas as pd +from ncls import NCLS + +from mmda.types.annotation import Annotation, Box, BoxGroup, SpanGroup @dataclass @@ -56,7 +57,7 @@ def __init__(self, span_groups: List[SpanGroup]) -> None: self._index = NCLS( pd.Series(starts, dtype=np.int64), pd.Series(ends, dtype=np.int64), - pd.Series(ids, dtype=np.int64) + pd.Series(ids, dtype=np.int64), ) self._ensure_disjoint() @@ -68,15 +69,26 @@ def _ensure_disjoint(self) -> None: """ for span_group in self._sgs: for span in span_group.spans: - matches = [match for match in self._index.find_overlap(span.start, span.end)] - if len(matches) > 1: + match_ids = [ + matched_id + for _start, _end, matched_id in self._index.find_overlap( + span.start, span.end + ) + ] + if len(match_ids) > 1: + matches = [self._sgs[match_id].to_json() for match_id in match_ids] raise ValueError( - f"Detected overlap with existing SpanGroup(s) {matches} for {span_group}" + f"Detected overlap! While processing the Span {span} as part of query SpanGroup {span_group.to_json()}, we found that it overlaps with existing SpanGroup(s):\n" + + "\n".join( + [f"\t{i}\t{m} " for i, m in zip(match_ids, matches)] + ) ) def find(self, query: SpanGroup) -> List[SpanGroup]: if not isinstance(query, SpanGroup): - raise ValueError(f'SpanGroupIndexer only works with `query` that is SpanGroup type') + raise ValueError( + f"SpanGroupIndexer only works with `query` that is SpanGroup type" + ) if not query.spans: return [] @@ -84,7 +96,9 @@ def find(self, query: SpanGroup) -> List[SpanGroup]: matched_ids = set() for span in query.spans: - for _start, _end, matched_id in self._index.find_overlap(span.start, span.end): + for _start, _end, matched_id in self._index.find_overlap( + span.start, span.end + ): matched_ids.add(matched_id) matched_span_groups = [self._sgs[matched_id] for matched_id in matched_ids] @@ -95,3 +109,78 @@ def find(self, query: SpanGroup) -> List[SpanGroup]: return sorted(list(matched_span_groups)) +class BoxGroupIndexer(Indexer): + """ + Manages a data structure for locating overlapping BoxGroups. + Builds a static nested containment list from BoxGroups + and accepts other BoxGroups as search probes. + """ + + def __init__(self, box_groups: List[BoxGroup]) -> None: + self._bgs = box_groups + + self._box_id_to_box_group_id = {} + self._boxes = [] + box_id = 0 + for bg_id, bg in enumerate(box_groups): + for box in bg.boxes: + self._boxes.append(box) + self._box_id_to_box_group_id[box_id] = bg_id + box_id += 1 + + self._np_boxes_x1 = np.array([b.l for b in self._boxes]) + self._np_boxes_y1 = np.array([b.t for b in self._boxes]) + self._np_boxes_x2 = np.array([b.l + b.w for b in self._boxes]) + self._np_boxes_y2 = np.array([b.t + b.h for b in self._boxes]) + self._np_boxes_page = np.array([b.page for b in self._boxes]) + + self._ensure_disjoint() + + def _find_overlap_boxes(self, query: Box) -> List[int]: + x1, y1, x2, y2 = query.coordinates + mask = ( + (self._np_boxes_x1 <= x2) + & (self._np_boxes_x2 >= x1) + & (self._np_boxes_y1 <= y2) + & (self._np_boxes_y2 >= y1) + & (self._np_boxes_page == query.page) + ) + return np.where(mask)[0].tolist() + + def _find_overlap_box_groups(self, query: Box) -> List[int]: + return [ + self._box_id_to_box_group_id[box_id] + for box_id in self._find_overlap_boxes(query) + ] + + def _ensure_disjoint(self) -> None: + """ + Constituent box groups must be fully disjoint. + Ensure the integrity of the built index. + """ + for box_group in self._bgs: + for box in box_group.boxes: + match_ids = self._find_overlap_box_groups(query=box) + if len(match_ids) > 1: + matches = [self._bgs[match_id].to_json() for match_id in match_ids] + raise ValueError( + f"Detected overlap! While processing the Box {box} as part of query BoxGroup {box_group.to_json()}, we found that it overlaps with existing BoxGroup(s):\n" + + "\n".join( + [f"\t{i}\t{m} " for i, m in zip(match_ids, matches)] + ) + ) + + def find(self, query: BoxGroup) -> List[BoxGroup]: + if not isinstance(query, BoxGroup): + raise ValueError( + f"BoxGroupIndexer only works with `query` that is BoxGroup type" + ) + + if not query.boxes: + return [] + + match_ids = [] + for box in query.boxes: + match_ids.extend(self._find_overlap_box_groups(query=box)) + + return [self._bgs[match_id] for match_id in sorted(set(match_ids))] diff --git a/tests/test_types/test_indexers.py b/tests/test_types/test_indexers.py index 05ab6885..2cb7b33c 100644 --- a/tests/test_types/test_indexers.py +++ b/tests/test_types/test_indexers.py @@ -1,19 +1,13 @@ import unittest -from mmda.types import SpanGroup, Span -from mmda.types.indexers import SpanGroupIndexer +from mmda.types import Box, BoxGroup, Span, SpanGroup +from mmda.types.indexers import BoxGroupIndexer, SpanGroupIndexer class TestSpanGroupIndexer(unittest.TestCase): def test_overlap_within_single_spangroup_fails_checks(self): span_groups = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(4, 7) - ] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(4, 7)], allow_overlap=True) ] with self.assertRaises(ValueError): @@ -21,17 +15,8 @@ def test_overlap_within_single_spangroup_fails_checks(self): def test_overlap_between_spangroups_fails_checks(self): span_groups = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(5, 8) - ] - ), - SpanGroup( - id=2, - spans=[Span(6, 10)] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(5, 8)]), + SpanGroup(id=2, spans=[Span(6, 10)]), ] with self.assertRaises(ValueError): @@ -39,21 +24,9 @@ def test_overlap_between_spangroups_fails_checks(self): def test_finds_matching_groups_in_doc_order(self): span_groups_to_index = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(5, 8) - ] - ), - SpanGroup( - id=2, - spans=[Span(9, 10)] - ), - SpanGroup( - id=3, - spans=[Span(100, 105)] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(5, 8)]), + SpanGroup(id=2, spans=[Span(9, 10)]), + SpanGroup(id=3, spans=[Span(100, 105)]), ] index = SpanGroupIndexer(span_groups_to_index) @@ -66,4 +39,67 @@ def test_finds_matching_groups_in_doc_order(self): self.assertEqual(matches, [span_groups_to_index[0], span_groups_to_index[1]]) +class TestBoxGroupIndexer(unittest.TestCase): + def test_overlap_within_single_boxgroup_fails_checks(self): + box_groups = [ + BoxGroup( + id=1, + boxes=[Box(0, 0, 5, 5, page=0), Box(4, 4, 7, 7, page=0)], + allow_overlap=True, + ) + ] + + with self.assertRaises(ValueError): + BoxGroupIndexer(box_groups) + + def test_overlap_between_boxgroups_fails_checks(self): + box_groups = [ + BoxGroup( + id=1, boxes=[Box(0, 0, 5, 5, page=0), Box(5.01, 5.01, 8, 8, page=0)] + ), + BoxGroup(id=2, boxes=[Box(6, 6, 10, 10, page=0)]), + ] + + with self.assertRaises(ValueError): + BoxGroupIndexer(box_groups) + def test_finds_matching_groups_in_doc_order(self): + box_groups_to_index = [ + BoxGroup(id=1, boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=0)]), + BoxGroup(id=2, boxes=[Box(4, 4, 1, 1, page=0)]), + BoxGroup(id=3, boxes=[Box(100, 100, 1, 1, page=0)]), + ] + + index = BoxGroupIndexer(box_groups_to_index) + + # should intersect 1 and 2 but not 3 + probe = BoxGroup(id=4, boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)]) + matches = index.find(probe) + + self.assertEqual(len(matches), 2) + self.assertEqual(matches, [box_groups_to_index[0], box_groups_to_index[1]]) + + def test_finds_matching_groups_accounts_for_pages(self): + box_groups_to_index = [ + BoxGroup(id=1, boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=1)]), + BoxGroup(id=2, boxes=[Box(4, 4, 1, 1, page=1)]), + BoxGroup(id=3, boxes=[Box(100, 100, 1, 1, page=0)]), + ] + + index = BoxGroupIndexer(box_groups_to_index) + + # shouldnt intersect any given page 0 + probe = BoxGroup(id=4, boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)]) + matches = index.find(probe) + + self.assertEqual(len(matches), 1) + self.assertEqual(matches, [box_groups_to_index[0]]) + + # shoudl intersect after switching to page 1 (and the page 2 box doesnt intersect) + probe = BoxGroup( + id=4, boxes=[Box(1, 1, 5, 5, page=1), Box(100, 100, 1, 1, page=2)] + ) + matches = index.find(probe) + + self.assertEqual(len(matches), 2) + self.assertEqual(matches, [box_groups_to_index[0], box_groups_to_index[1]]) From caf81c314ba44da520d494c20688c6fe46e4c012 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 13:20:45 -0700 Subject: [PATCH 11/25] bugfix unit test --- tests/test_types/test_document.py | 41 ++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/tests/test_types/test_document.py b/tests/test_types/test_document.py index 355e4993..13256f94 100644 --- a/tests/test_types/test_document.py +++ b/tests/test_types/test_document.py @@ -1,13 +1,12 @@ import json -import unittest import os +import unittest +from ai2_internal import api from mmda.types.annotation import SpanGroup from mmda.types.document import Document from mmda.types.names import MetadataField, SymbolsField -from ai2_internal import api - def resolve(file: str) -> str: return os.path.join(os.path.dirname(__file__), "../fixtures/types", file) @@ -58,10 +57,15 @@ def test_annotate_box_groups_gets_text(self): spp_doc = Document.from_json(json.load(f)) with open(resolve("test_document_box_groups.json")) as f: - box_groups = [api.BoxGroup(**bg).to_mmda() for bg in json.load(f)["grobid_bibs_box_groups"]] + box_groups = [ + api.BoxGroup(**bg).to_mmda() + for bg in json.load(f)["grobid_bibs_box_groups"] + ] spp_doc.annotate(new_span_groups=box_groups) - assert spp_doc.new_span_groups[0].text.startswith("Gutman G, Rosenzweig D, Golan J") + assert spp_doc.new_span_groups[0].text.startswith( + "Gutman G, Rosenzweig D, Golan J" + ) # when token boxes are on spans plumber_doc = "c8b53e2d9cd247e2d42719e337bfb13784d22bd2.json" @@ -70,7 +74,10 @@ def test_annotate_box_groups_gets_text(self): doc = Document.from_json(json.load(f)) with open(resolve("test_document_box_groups.json")) as f: - box_groups = [api.BoxGroup(**bg).to_mmda() for bg in json.load(f)["grobid_bibs_box_groups"]] + box_groups = [ + api.BoxGroup(**bg).to_mmda() + for bg in json.load(f)["grobid_bibs_box_groups"] + ] doc.annotate(new_span_groups=box_groups) assert doc.new_span_groups[0].text.startswith("Gutman G, Rosenzweig D, Golan J") @@ -80,7 +87,10 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): a last-known-good fixture """ # basic doc annotated with pages and tokens, from pdfplumber parser split at punctuation - with open(resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json"), "r") as f: + with open( + resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json"), + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) doc = Document.from_json(fixture_doc_json) @@ -88,9 +98,16 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): # spangroups derived from boxgroups of boxes drawn neatly around bib entries by calling `.annotate` on # list of BoxGroups fixture_span_groups = [] - with open(resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entry_span_groups_from_box_groups.json"), "r") as f: + with open( + resolve( + "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entry_span_groups_from_box_groups.json" + ), + "r", + ) as f: raw_json = f.read() - fixture_bib_entries_json = json.loads(raw_json)["bib_entry_span_groups_from_box_groups"] + fixture_bib_entries_json = json.loads(raw_json)[ + "bib_entry_span_groups_from_box_groups" + ] # make box_groups to annotate from test fixture bib entry span groups, and save the for bib_entry in fixture_bib_entries_json: @@ -105,6 +122,8 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): doc.annotate(fixture_box_groups=fixture_box_groups) for sg1, sg2 in zip(fixture_span_groups, doc.fixture_box_groups): - assert sg1.spans == sg2.spans assert sg1.text == sg2.text - + for sg1_span, sg2_span in zip(sg1.spans, sg2.spans): + assert sg1_span.start == sg2_span.start + assert sg1_span.end == sg2_span.end + assert sg1_span.box.coordinates == sg2_span.box.coordinates From d7afec35d3a2fa2de35616687e2f8a405931d74f Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 13:24:39 -0700 Subject: [PATCH 12/25] bugfix; metrics --- src/mmda/eval/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mmda/eval/metrics.py b/src/mmda/eval/metrics.py index abc431f9..dfdd36ce 100644 --- a/src/mmda/eval/metrics.py +++ b/src/mmda/eval/metrics.py @@ -75,11 +75,11 @@ def levenshtein( def box_overlap(box: Box, container: Box) -> float: """Returns the percentage of area of a box inside of a container.""" - bl, bt, bw, bh = box.xywh + bl, bt, bw, bh = box.l, box.t, box.w, box.h br = bl + bw bb = bt + bh - cl, ct, cw, ch = container.xywh + cl, ct, cw, ch = container.l, container.t, container.w, container.h cr = cl + cw cb = ct + ch From 37b223e05769530ccc374675f3a7bd3b4d8fa710 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 13:28:21 -0700 Subject: [PATCH 13/25] bugfix test --- tests/test_parsers/test_pdf_plumber_parser.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/tests/test_parsers/test_pdf_plumber_parser.py b/tests/test_parsers/test_pdf_plumber_parser.py index 6ea06117..ce9a7280 100644 --- a/tests/test_parsers/test_pdf_plumber_parser.py +++ b/tests/test_parsers/test_pdf_plumber_parser.py @@ -19,7 +19,7 @@ class TestPDFPlumberParser(unittest.TestCase): def setUp(cls) -> None: cls.fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" - ''' + """ def test_parse(self): parser = PDFPlumberParser() doc = parser.parse(input_pdf_path=self.fixture_path / "1903.10676.pdf") @@ -207,7 +207,7 @@ def test_convert_nested_text_to_doc_json(self): "abc\nd ef", "gh i\njkl", ] - ''' + """ def test_parser_stability(self): """ @@ -223,15 +223,25 @@ def test_parser_stability(self): parser = PDFPlumberParser() - current_doc = parser.parse(input_pdf_path=self.fixture_path / "4be952924cd565488b4a239dc6549095029ee578.pdf") + current_doc = parser.parse( + input_pdf_path=self.fixture_path + / "4be952924cd565488b4a239dc6549095029ee578.pdf" + ) - with open(self.fixture_path / "4be952924cd565488b4a239dc6549095029ee578__pdfplumber_doc.json", "r") as f: + with open( + self.fixture_path + / "4be952924cd565488b4a239dc6549095029ee578__pdfplumber_doc.json", + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) fixture_doc = Document.from_json(fixture_doc_json) - - self.assertEqual(current_doc.symbols, fixture_doc.symbols, msg="Current parse has extracted different text from pdf.") + self.assertEqual( + current_doc.symbols, + fixture_doc.symbols, + msg="Current parse has extracted different text from pdf.", + ) def compare_span_groups(current_doc_sgs, fixture_doc_sgs, annotation_name): current_doc_sgs_simplified = [ @@ -244,17 +254,23 @@ def compare_span_groups(current_doc_sgs, fixture_doc_sgs, annotation_name): self.assertEqual( current_doc_sgs_simplified, fixture_doc_sgs_simplified, - msg=f"Current parse produces different SpanGroups for `{annotation_name}`" + msg=f"Current parse produces different SpanGroups for `{annotation_name}`", ) - current_doc_sg_boxes = [[list(s.box.xywh) + [s.box.page] for s in sg] for sg in current_doc_sgs] - fixture_doc_sg_boxes = [[list(s.box.xywh) + [s.box.page] for s in sg] for sg in current_doc_sgs] + current_doc_sg_boxes = [ + [list(s.box.coordinates) + [s.box.page] for s in sg] + for sg in current_doc_sgs + ] + fixture_doc_sg_boxes = [ + [list(s.box.coordinates) + [s.box.page] for s in sg] + for sg in current_doc_sgs + ] self.assertAlmostEqual( current_doc_sg_boxes, fixture_doc_sg_boxes, places=3, - msg=f"Boxes generated for `{annotation_name}` have changed." + msg=f"Boxes generated for `{annotation_name}` have changed.", ) compare_span_groups(current_doc.tokens, fixture_doc.tokens, "tokens") From 439307bb56c91d8991541113f42a81726ba17b2a Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 17:31:33 -0700 Subject: [PATCH 14/25] CHANGE! word predictor doesnt return merged Box anymore as part of its SpanGroup --- .../predictors/sklearn_predictors/svm_word_predictor.py | 8 ++++++-- tests/test_predictors/test_svm_word_predictor.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py b/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py index 13bdb2f6..bc926489 100644 --- a/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py +++ b/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py @@ -451,7 +451,10 @@ def _create_words( else: spans = [ Span.small_spans_to_big_span( - spans=[span for token in tokens_in_word for span in token.spans] + spans=[ + span for token in tokens_in_word for span in token.spans + ], + merge_boxes=False, ) ] metadata = ( @@ -471,7 +474,8 @@ def _create_words( # last bit spans = [ Span.small_spans_to_big_span( - spans=[span for token in tokens_in_word for span in token.spans] + spans=[span for token in tokens_in_word for span in token.spans], + merge_boxes=False, ) ] metadata = ( diff --git a/tests/test_predictors/test_svm_word_predictor.py b/tests/test_predictors/test_svm_word_predictor.py index 308229d7..f4a179f8 100644 --- a/tests/test_predictors/test_svm_word_predictor.py +++ b/tests/test_predictors/test_svm_word_predictor.py @@ -268,7 +268,7 @@ def test_cluster_tokens_by_whitespace(self): else: tokens = [self.doc.tokens[token_id] for token_id in token_ids] spans = [span for token in tokens for span in token.spans] - big_span = Span.small_spans_to_big_span(spans=spans) + big_span = Span.small_spans_to_big_span(spans=spans, merge_boxes=False) self.assertEqual( self.doc.symbols[big_span.start : big_span.end], "".join([token.text for token in tokens]), From db07e68769cf2427c17dfe62accc7bfdb9e88b2a Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 20:28:40 -0700 Subject: [PATCH 15/25] reformat; remove unused imports from Doc --- src/mmda/types/document.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/mmda/types/document.py b/src/mmda/types/document.py index 4f2457d3..ab5760fd 100644 --- a/src/mmda/types/document.py +++ b/src/mmda/types/document.py @@ -14,7 +14,7 @@ from mmda.types.indexers import Indexer, SpanGroupIndexer from mmda.types.metadata import Metadata from mmda.types.names import ImagesField, MetadataField, SymbolsField -from mmda.utils.tools import MergeSpans, allocate_overlapping_tokens_for_box, box_groups_to_span_groups +from mmda.utils.tools import box_groups_to_span_groups class Document: @@ -46,7 +46,7 @@ def add_metadata(self, **kwargs): self.metadata.set(k, value) def annotate( - self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] + self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] ) -> None: """Annotate the fields for document symbols (correlating the annotations with the symbols) and store them into the papers. @@ -54,7 +54,7 @@ def annotate( # 1) check validity of field names for field_name in kwargs.keys(): assert ( - field_name not in self.SPECIAL_FIELDS + field_name not in self.SPECIAL_FIELDS ), f"The field_name {field_name} should not be in {self.SPECIAL_FIELDS}." if field_name in self.fields: @@ -83,7 +83,7 @@ def annotate( annotation_types = {type(a) for a in annotations} assert ( - len(annotation_types) == 1 + len(annotation_types) == 1 ), f"Annotations in field_name {field_name} more than 1 type: {annotation_types}" annotation_type = annotation_types.pop() @@ -94,7 +94,8 @@ def annotate( elif annotation_type == BoxGroup: # TODO: not good. BoxGroups should be stored on their own, not auto-generating SpanGroups. span_groups = self._annotate_span_group( - span_groups=box_groups_to_span_groups(annotations, self), field_name=field_name + span_groups=box_groups_to_span_groups(annotations, self), + field_name=field_name, ) else: raise NotImplementedError( @@ -111,7 +112,7 @@ def remove(self, field_name: str): del self.__indexers[field_name] def annotate_images( - self, images: Iterable[PILImage], is_overwrite: bool = False + self, images: Iterable[PILImage], is_overwrite: bool = False ) -> None: if not is_overwrite and len(self.images) > 0: raise AssertionError( @@ -133,7 +134,7 @@ def annotate_images( self.images = images def _annotate_span_group( - self, span_groups: List[SpanGroup], field_name: str + self, span_groups: List[SpanGroup], field_name: str ) -> List[SpanGroup]: """Annotate the Document using a bunch of span groups. It will associate the annotations with the document symbols. From 039f0dc5226c04a3425a579ab93f42ffc82c27df Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 20:29:21 -0700 Subject: [PATCH 16/25] reformatting --- .../test_figure_table_predictors.py | 317 ++++++++++++------ 1 file changed, 206 insertions(+), 111 deletions(-) diff --git a/tests/test_predictors/test_figure_table_predictors.py b/tests/test_predictors/test_figure_table_predictors.py index 1151f6b6..d0f6fa7b 100644 --- a/tests/test_predictors/test_figure_table_predictors.py +++ b/tests/test_predictors/test_figure_table_predictors.py @@ -1,54 +1,71 @@ import json +import pathlib import pickle import unittest from collections import defaultdict -import pathlib + import pytest from ai2_internal.api import Relation -from mmda.predictors.heuristic_predictors.figure_table_predictors import FigureTablePredictions -from mmda.types import Document, BoxGroup -from mmda.types.box import Box -from mmda.types.span import Span +from mmda.predictors.heuristic_predictors.figure_table_predictors import ( + FigureTablePredictions, +) +from mmda.types import Box, BoxGroup, Document, Span class TestFigureCaptionPredictor(unittest.TestCase): @classmethod def setUp(cls): cls.fixture_path = pathlib.Path(__file__).parent.parent - with open(cls.fixture_path / 'fixtures/doc_fixture_e5910c027af0ee9c1901c57f6579d903aedee7f4.pkl', - 'rb') as file_handle: + with open( + cls.fixture_path + / "fixtures/doc_fixture_e5910c027af0ee9c1901c57f6579d903aedee7f4.pkl", + "rb", + ) as file_handle: doc_json = pickle.load(file_handle) cls.doc = Document.from_json(doc_json) assert cls.doc.pages assert cls.doc.tokens assert cls.doc.blocks assert cls.doc.vila_span_groups - with open(cls.fixture_path / 'fixtures/doc_fixture_2149e0c1106e6dfa36ea787167d6611cf88b69cb.json', - 'rb') as file_handle: + with open( + cls.fixture_path + / "fixtures/doc_fixture_2149e0c1106e6dfa36ea787167d6611cf88b69cb.json", + "rb", + ) as file_handle: dic_json = json.load(file_handle) - cls.doc_2 = Document.from_json(dic_json['doc']) - layout_equations = [BoxGroup.from_json(entry) for entry in dic_json['layout_equations']] + cls.doc_2 = Document.from_json(dic_json["doc"]) + layout_equations = [ + BoxGroup.from_json(entry) for entry in dic_json["layout_equations"] + ] cls.doc_2.annotate(blocks=layout_equations) - with open(cls.fixture_path / 'fixtures/figure_table_predictions.json', 'r') as file: + with open( + cls.fixture_path / "fixtures/figure_table_predictions.json", "r" + ) as file: cls.figure_predictions = json.load(file) cls.figure_table_predictor = FigureTablePredictions(cls.doc) def test_merge_boxes(self): - result = self.figure_table_predictor.merge_boxes(self.doc.blocks, defaultdict(list)) + result = self.figure_table_predictor.merge_boxes( + self.doc.blocks, defaultdict(list) + ) assert list(result[0].keys()) == [0, 2, 3, 7] assert isinstance(result[0][0][0], Span) def test_get_figure_caption_distance(self): distance = FigureTablePredictions._get_object_caption_distance( - Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), Box(l=0.3, t=0.3, w=0.1, h=0.1, page=0)) + Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), + Box(l=0.3, t=0.3, w=0.1, h=0.1, page=0), + ) assert distance == 900 distance = FigureTablePredictions._get_object_caption_distance( - Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), Box(l=0.2, t=0.3, w=0.1, h=0.1, page=0)) + Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), + Box(l=0.2, t=0.3, w=0.1, h=0.1, page=0), + ) assert distance == pytest.approx(0.15) @@ -57,12 +74,17 @@ def test_generate_map_of_layout_to_tokens(self): Test that the function generates a map of layout to tokens using """ vila_caption = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, caption_content='fig', span_group_types=['Caption']) + self.doc.vila_span_groups, + caption_content="fig", + span_group_types=["Caption"], + ) - vila_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila(vila_caption) + vila_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila( + vila_caption + ) result = self.figure_table_predictor.generate_map_of_layout_to_tokens( - vila_caption_dict, - defaultdict(list), defaultdict(list)) + vila_caption_dict, defaultdict(list), defaultdict(list) + ) assert list(result.keys()) == [] def test_predict_e5910c027af0ee9c1901c57f6579d903aedee7f4(self): @@ -72,71 +94,134 @@ def test_predict_e5910c027af0ee9c1901c57f6579d903aedee7f4(self): """ result = self.figure_table_predictor.predict() assert isinstance(result, dict) - assert list(result.keys()) == ['figures', 'figure_captions', 'figure_to_figure_captions', 'tables', - 'table_captions', - 'table_to_table_captions', ] - assert len(result['figures']) == 4 - assert len(result['tables']) == 4 - assert isinstance(result['figure_to_figure_captions'][0], Relation) - assert isinstance(result['table_to_table_captions'][0], Relation) - assert [figure.to_json() for figure in result['figures']] == [{'boxes': [{'height': 0.130624674787425, - 'left': 0.5021962683185254, - 'page': 0, - 'top': 0.3574526237718987, - 'width': 0.3930938321780535}]}, - {'boxes': [{'height': 0.21034525861643782, - 'left': 0.08724006952023973, - 'page': 2, - 'top': 0.09557842485832446, - 'width': 0.3754700804068372}], - 'id': 1}, - {'boxes': [{'height': 0.31222110318652835, - 'left': 0.08188235294117646, - 'page': 3, - 'top': 0.08723311954074436, - 'width': 0.37919526861851516}], - 'id': 2}, - {'boxes': [{'height': 0.3527590433756511, - 'left': 0.09958468543158637, - 'page': 7, - 'top': 0.08601251274648339, - 'width': 0.8034834020278033}], - 'id': 3}] - assert [figure_caption.to_json() for figure_caption in result['figure_captions']] == [ - {'id': 0, 'metadata': {}, 'spans': [{'end': 2057, 'start': 2034}]}, - {'id': 1, 'metadata': {}, 'spans': [{'end': 9679, 'start': 9175}]}, - {'id': 2, 'metadata': {}, 'spans': [{'end': 13875, 'start': 13822}]}, - {'id': 3, 'metadata': {}, 'spans': [{'end': 31364, 'start': 31224}]}] - - assert [table.to_json() for table in result['tables']] == [{'boxes': [{'height': 0.2796805025351168, - 'left': 0.16789371515411178, - 'page': 4, - 'top': 0.1370883614125878, - 'width': 0.6443845462175756}]}, - {'boxes': [{'height': 0.20913203075678666, - 'left': 0.1747694701151131, - 'page': 5, - 'top': 0.13721680882001164, - 'width': 0.622537251391442}], - 'id': 1}, - {'boxes': [{'height': 0.06003320096719145, - 'left': 0.15402431114047183, - 'page': 5, - 'top': 0.5840287642045454, - 'width': 0.2569979998021344}], - 'id': 2}, - {'boxes': [{'height': 0.23519277090978136, - 'left': 0.5027104296715431, - 'page': 6, - 'top': 0.27805763784081045, - 'width': 0.3950077131682751}], - 'id': 3}] - - assert [table_caption.to_json() for table_caption in result['table_captions']] == [ - {'id': 0, 'metadata': {}, 'spans': [{'end': 18359, 'start': 18198}]}, - {'id': 1, 'metadata': {}, 'spans': [{'end': 22214, 'start': 22042}]}, - {'id': 2, 'metadata': {}, 'spans': [{'end': 23502, 'start': 23400}]}, - {'id': 3, 'metadata': {}, 'spans': [{'end': 29584, 'start': 29369}]}] + assert list(result.keys()) == [ + "figures", + "figure_captions", + "figure_to_figure_captions", + "tables", + "table_captions", + "table_to_table_captions", + ] + assert len(result["figures"]) == 4 + assert len(result["tables"]) == 4 + assert isinstance(result["figure_to_figure_captions"][0], Relation) + assert isinstance(result["table_to_table_captions"][0], Relation) + assert [figure.to_json() for figure in result["figures"]] == [ + { + "boxes": [ + { + "height": 0.130624674787425, + "left": 0.5021962683185254, + "page": 0, + "top": 0.3574526237718987, + "width": 0.3930938321780535, + } + ] + }, + { + "boxes": [ + { + "height": 0.21034525861643782, + "left": 0.08724006952023973, + "page": 2, + "top": 0.09557842485832446, + "width": 0.3754700804068372, + } + ], + "id": 1, + }, + { + "boxes": [ + { + "height": 0.31222110318652835, + "left": 0.08188235294117646, + "page": 3, + "top": 0.08723311954074436, + "width": 0.37919526861851516, + } + ], + "id": 2, + }, + { + "boxes": [ + { + "height": 0.3527590433756511, + "left": 0.09958468543158637, + "page": 7, + "top": 0.08601251274648339, + "width": 0.8034834020278033, + } + ], + "id": 3, + }, + ] + assert [ + figure_caption.to_json() for figure_caption in result["figure_captions"] + ] == [ + {"id": 0, "metadata": {}, "spans": [{"end": 2057, "start": 2034}]}, + {"id": 1, "metadata": {}, "spans": [{"end": 9679, "start": 9175}]}, + {"id": 2, "metadata": {}, "spans": [{"end": 13875, "start": 13822}]}, + {"id": 3, "metadata": {}, "spans": [{"end": 31364, "start": 31224}]}, + ] + + assert [table.to_json() for table in result["tables"]] == [ + { + "boxes": [ + { + "height": 0.2796805025351168, + "left": 0.16789371515411178, + "page": 4, + "top": 0.1370883614125878, + "width": 0.6443845462175756, + } + ] + }, + { + "boxes": [ + { + "height": 0.20913203075678666, + "left": 0.1747694701151131, + "page": 5, + "top": 0.13721680882001164, + "width": 0.622537251391442, + } + ], + "id": 1, + }, + { + "boxes": [ + { + "height": 0.06003320096719145, + "left": 0.15402431114047183, + "page": 5, + "top": 0.5840287642045454, + "width": 0.2569979998021344, + } + ], + "id": 2, + }, + { + "boxes": [ + { + "height": 0.23519277090978136, + "left": 0.5027104296715431, + "page": 6, + "top": 0.27805763784081045, + "width": 0.3950077131682751, + } + ], + "id": 3, + }, + ] + + assert [ + table_caption.to_json() for table_caption in result["table_captions"] + ] == [ + {"id": 0, "metadata": {}, "spans": [{"end": 18359, "start": 18198}]}, + {"id": 1, "metadata": {}, "spans": [{"end": 22214, "start": 22042}]}, + {"id": 2, "metadata": {}, "spans": [{"end": 23502, "start": 23400}]}, + {"id": 3, "metadata": {}, "spans": [{"end": 29584, "start": 29369}]}, + ] def test_predict_2149e0c1106e6dfa36ea787167d6611cf88b69cb(self): """ @@ -146,30 +231,40 @@ def test_predict_2149e0c1106e6dfa36ea787167d6611cf88b69cb(self): self.figure_table_predictor.doc = self.doc_2 result = self.figure_table_predictor.predict() assert isinstance(result, dict) - assert list(result.keys()) == ['figures', 'figure_captions', 'figure_to_figure_captions', 'tables', - 'table_captions', - 'table_to_table_captions', ] - assert len(result['figures']) == 19 - assert len(result['tables']) == 0 - assert isinstance(result['figure_to_figure_captions'][0], Relation) - assert [figure.to_json() for figure in result['figures']] == self.figure_predictions - assert [figure_caption.to_json() for figure_caption in result['figure_captions']] == [ - {'id': 0, 'metadata': {}, 'spans': [{'end': 5253, 'start': 5019}]}, - {'id': 1, 'metadata': {}, 'spans': [{'end': 9230, 'start': 8976}]}, - {'id': 2, 'metadata': {}, 'spans': [{'end': 13164, 'start': 12935}]}, - {'id': 3, 'metadata': {}, 'spans': [{'end': 17600, 'start': 17373}]}, - {'id': 4, 'metadata': {}, 'spans': [{'end': 23624, 'start': 23205}]}, - {'id': 5, 'metadata': {}, 'spans': [{'end': 21009, 'start': 20070}]}, - {'id': 6, 'metadata': {}, 'spans': [{'end': 28975, 'start': 28838}]}, - {'id': 7, 'metadata': {}, 'spans': [{'end': 32839, 'start': 32681}]}, - {'id': 8, 'metadata': {}, 'spans': [{'end': 37061, 'start': 36394}]}, - {'id': 9, 'metadata': {}, 'spans': [{'end': 42245, 'start': 42063}]}, - {'id': 10, 'metadata': {}, 'spans': [{'end': 43512, 'start': 43418}]}, - {'id': 11, 'metadata': {}, 'spans': [{'end': 46726, 'start': 46542}]}, - {'id': 12, 'metadata': {}, 'spans': [{'end': 50359, 'start': 50192}]}, - {'id': 13, 'metadata': {}, 'spans': [{'end': 57779, 'start': 57323}]}, - {'id': 14, 'metadata': {}, 'spans': [{'end': 60918, 'start': 60838}]}, - {'id': 15, 'metadata': {}, 'spans': [{'end': 64943, 'start': 64238}]}, - {'id': 16, 'metadata': {}, 'spans': [{'end': 69170, 'start': 68548}]}, - {'id': 17, 'metadata': {}, 'spans': [{'end': 75951, 'start': 75767}]}, - {'id': 18, 'metadata': {}, 'spans': [{'end': 80129, 'start': 79561}]}] + assert list(result.keys()) == [ + "figures", + "figure_captions", + "figure_to_figure_captions", + "tables", + "table_captions", + "table_to_table_captions", + ] + assert len(result["figures"]) == 19 + assert len(result["tables"]) == 0 + assert isinstance(result["figure_to_figure_captions"][0], Relation) + assert [ + figure.to_json() for figure in result["figures"] + ] == self.figure_predictions + assert [ + figure_caption.to_json() for figure_caption in result["figure_captions"] + ] == [ + {"id": 0, "metadata": {}, "spans": [{"end": 5253, "start": 5019}]}, + {"id": 1, "metadata": {}, "spans": [{"end": 9230, "start": 8976}]}, + {"id": 2, "metadata": {}, "spans": [{"end": 13164, "start": 12935}]}, + {"id": 3, "metadata": {}, "spans": [{"end": 17600, "start": 17373}]}, + {"id": 4, "metadata": {}, "spans": [{"end": 23624, "start": 23205}]}, + {"id": 5, "metadata": {}, "spans": [{"end": 21009, "start": 20070}]}, + {"id": 6, "metadata": {}, "spans": [{"end": 28975, "start": 28838}]}, + {"id": 7, "metadata": {}, "spans": [{"end": 32839, "start": 32681}]}, + {"id": 8, "metadata": {}, "spans": [{"end": 37061, "start": 36394}]}, + {"id": 9, "metadata": {}, "spans": [{"end": 42245, "start": 42063}]}, + {"id": 10, "metadata": {}, "spans": [{"end": 43512, "start": 43418}]}, + {"id": 11, "metadata": {}, "spans": [{"end": 46726, "start": 46542}]}, + {"id": 12, "metadata": {}, "spans": [{"end": 50359, "start": 50192}]}, + {"id": 13, "metadata": {}, "spans": [{"end": 57779, "start": 57323}]}, + {"id": 14, "metadata": {}, "spans": [{"end": 60918, "start": 60838}]}, + {"id": 15, "metadata": {}, "spans": [{"end": 64943, "start": 64238}]}, + {"id": 16, "metadata": {}, "spans": [{"end": 69170, "start": 68548}]}, + {"id": 17, "metadata": {}, "spans": [{"end": 75951, "start": 75767}]}, + {"id": 18, "metadata": {}, "spans": [{"end": 80129, "start": 79561}]}, + ] From d163044e00a3f379e296b54a51e0006d3194ee5b Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 20:30:28 -0700 Subject: [PATCH 17/25] reformat --- tests/test_utils/test_tools.py | 987 +++++++++++++++++++++++++-------- 1 file changed, 755 insertions(+), 232 deletions(-) diff --git a/tests/test_utils/test_tools.py b/tests/test_utils/test_tools.py index 6b39bdb1..e6c1c6e5 100644 --- a/tests/test_utils/test_tools.py +++ b/tests/test_utils/test_tools.py @@ -10,18 +10,15 @@ import unittest from mmda.types.annotation import BoxGroup, SpanGroup -from mmda.types.span import Span from mmda.types.box import Box from mmda.types.document import Document - -from mmda.utils.tools import MergeSpans -from mmda.utils.tools import box_groups_to_span_groups +from mmda.types.span import Span +from mmda.utils.tools import MergeSpans, box_groups_to_span_groups fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" / "utils" class TestMergeNeighborSpans(unittest.TestCase): - def test_merge_multiple_neighbor_spans(self): spans = [Span(start=0, end=10), Span(start=11, end=20), Span(start=21, end=30)] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) @@ -54,7 +51,9 @@ def test_different_index_distances(self): def test_zero_index_distance(self): spans = [Span(start=0, end=10), Span(start=10, end=20)] - out = MergeSpans(list_of_spans=spans, index_distance=0).merge_neighbor_spans_by_symbol_distance() + out = MergeSpans( + list_of_spans=spans, index_distance=0 + ).merge_neighbor_spans_by_symbol_distance() assert len(out) == 1 assert isinstance(out[0], Span) assert out[0].start == 0 @@ -64,7 +63,7 @@ def test_handling_of_boxes(self): spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=0)), - Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) + Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -84,7 +83,7 @@ def test_handling_of_boxes(self): spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=1)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=1)), - Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) + Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) @@ -104,7 +103,7 @@ def test_handling_of_boxes(self): Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20), Span(start=21, end=150), - Span(start=155, end=200) + Span(start=155, end=200), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -126,223 +125,721 @@ def test_handling_of_boxes(self): list_of_spans_to_merge = [ - Span(start=3944, end=3948, - box=Box(l=0.19238134915568578, t=0.22752901673615306, w=0.06941334053447479, h=0.029442207414270286, - page=4)), - Span(start=3949, end=3951, - box=Box(l=0.27220460878651254, t=0.22752901673615306, w=0.03468585042904468, h=0.029442207414270286, - page=4)), - Span(start=4060, end=4063, - box=Box(l=0.4204075769894973, t=0.34144142726484455, w=0.023417310961637895, h=0.014200429984914883, - page=4)), - Span(start=4072, end=4075, - box=Box(l=0.5182742633669088, t=0.34144142726484455, w=0.029000512031393755, h=0.014200429984914883, - page=4)), - Span(start=4076, end=4083, - box=Box(l=0.5522956396696659, t=0.34144142726484455, w=0.06440764687304719, h=0.014200429984914883, - page=4)), - Span(start=4119, end=4128, - box=Box(l=0.2686971421659869, t=0.36273518298114954, w=0.08479235581478171, h=0.014200429984914883, - page=4)), - Span(start=4134, end=4144, - box=Box(l=0.40387889180816966, t=0.36273518298114954, w=0.08368776567508182, h=0.014200429984914883, - page=4)), - Span(start=4145, end=4148, - box=Box(l=0.4943548659781345, t=0.36273518298114954, w=0.042396177907390975, h=0.014200429984914883, - page=4)), - Span(start=4149, end=4162, - box=Box(l=0.5435392523804085, t=0.36273518298114954, w=0.11491754144296094, h=0.014200429984914883, - page=4)), - Span(start=4166, end=4177, - box=Box(l=0.6876581404256177, t=0.36273518298114954, w=0.09146006356715199, h=0.014200429984914883, - page=4)), - Span(start=4419, end=4427, - box=Box(l=0.2686971421659869, t=0.4479113936500019, w=0.06846450520430858, h=0.014200429984914883, - page=4)), - Span(start=4497, end=4505, - box=Box(l=0.2686971421659869, t=0.46920514936630686, w=0.06846450520430858, h=0.014200429984914883, - page=4)), - Span(start=4517, end=4520, - box=Box(l=0.42195400318507725, t=0.46920514936630686, w=0.029000512031393755, h=0.014200429984914883, - page=4)), - Span(start=4574, end=4581, - box=Box(l=0.2686971421659869, t=0.49049890508261185, w=0.07810456460532592, h=0.014200429984914883, - page=4)), - Span(start=4582, end=4587, - box=Box(l=0.35061756361754887, t=0.49049890508261185, w=0.03904224057412029, h=0.014200429984914883, - page=4)), - Span(start=4588, end=4591, - box=Box(l=0.39347566103790516, t=0.49049890508261185, w=0.023417310961637943, h=0.014200429984914883, - page=4)), - Span(start=4592, end=4601, - box=Box(l=0.4207088288457791, t=0.49049890508261185, w=0.08254300862121101, h=0.014200429984914883, - page=4)), - Span(start=4602, end=4613, - box=Box(l=0.5070676943132262, t=0.49049890508261185, w=0.09481400090042272, h=0.014200429984914883, - page=4)),] - -list_of_spans_to_merge_2 = [Span(start=30113, end=30119, - box=Box(l=0.12095229775767885, t=0.3578497466414853, w=0.05243790645011725, - h=0.014200429984914883, page=19)), - Span(start=30120, end=30124, - box=Box(l=0.17929474059091924, t=0.3578497466414853, w=0.030687522426571887, - h=0.014200429984914883, - page=19)), - Span(start=30125, end=30129, - box=Box(l=0.21799556239458678, t=0.3578497466414853, w=0.04350076804709073, - h=0.014200429984914883, page=19)), - Span(start=30130, end=30135, - box=Box(l=0.26740086682480063, t=0.3578497466414853, w=0.050208642713631964, - h=0.014200429984914883, - page=19)), - Span(start=30136, end=30141, - box=Box(l=0.32351404592155575, t=0.3578497466414853, w=0.0446254416438761, - h=0.014200429984914883, page=19)), - Span(start=30142, end=30151, - box=Box(l=0.37404402394855496, t=0.3578497466414853, w=0.0769598075514552, - h=0.014200429984914883, page=19)), - Span(start=30152, end=30155, - box=Box(l=0.4569284513402187, t=0.3578497466414853, w=0.029000512031393852, - h=0.014200429984914883, page=19)), - Span(start=30156, end=30165, - box=Box(l=0.4918334997547357, t=0.3578497466414853, w=0.0792091547450259, - h=0.014200429984914883, page=19)), - Span(start=30166, end=30175, - box=Box(l=0.5769471908828846, t=0.3578497466414853, w=0.07175819216632291, - h=0.014200429984914883, page=19)), - Span(start=30176, end=30179, - box=Box(l=0.6576023545380633, t=0.3578497466414853, w=0.03122977576787907, - h=0.014200429984914883, page=19)), - Span(start=30180, end=30184, - box=Box(l=0.6947366666890655, t=0.3578497466414853, w=0.03904224057412024, - h=0.014200429984914883, page=19)), - Span(start=30185, end=30190, - box=Box(l=0.7396834436463088, t=0.3578497466414853, w=0.05020864271363187, - h=0.014200429984914883, page=19)), - Span(start=30191, end=30193, - box=Box(l=0.7957966227430638, t=0.3578497466414853, w=0.015624929612482252, - h=0.014200429984914883, page=19)), - Span(start=30194, end=30197, - box=Box(l=0.12095229775767885, t=0.37500875791374183, w=0.024541984558423317, - h=0.014200429984914883, - page=19)), - Span(start=30198, end=30207, - box=Box(l=0.1518205712980198, t=0.37500875791374183, w=0.07695980755145514, - h=0.014200429984914883, page=19)), - Span(start=30208, end=30210, - box=Box(l=0.2351066678313926, t=0.37500875791374183, w=0.013395665875996984, - h=0.014200429984914883, - page=19)), - Span(start=30211, end=30214, - box=Box(l=0.2548286226893072, t=0.37500875791374183, w=0.02231272082193805, - h=0.014200429984914883, page=19)), - Span(start=30215, end=30217, - box=Box(l=0.283467632493163, t=0.37500875791374183, w=0.015624929612482252, - h=0.014200429984914883, page=19)), - Span(start=30218, end=30221, - box=Box(l=0.3054188510875629, t=0.37500875791374183, w=0.024541984558423317, - h=0.014200429984914883, - page=19)), - Span(start=30222, end=30229, - box=Box(l=0.33628712462790383, t=0.37500875791374183, w=0.055570925755447906, - h=0.014200429984914883, - page=19)), - Span(start=30230, end=30235, - box=Box(l=0.3981843393652693, t=0.37500875791374183, w=0.04183384110899822, - h=0.014200429984914883, page=19)), - Span(start=30236, end=30240, - box=Box(l=0.44668588822663785, t=0.37500875791374183, w=0.03570838669793504, - h=0.014200429984914883, - page=19)), - Span(start=30241, end=30244, - box=Box(l=0.4887205639064905, t=0.37500875791374183, w=0.020083457085452783, - h=0.014200429984914883, - page=19)), - Span(start=30245, end=30255, - box=Box(l=0.5151303099738609, t=0.37500875791374183, w=0.08810612623388145, - h=0.014200429984914883, page=19)), - Span(start=30256, end=30259, - box=Box(l=0.6095627251896601, t=0.37500875791374183, w=0.022312720821938, - h=0.014200429984914883, page=19)), - Span(start=30260, end=30262, - box=Box(l=0.6382017349935157, t=0.37500875791374183, w=0.015624929612482252, - h=0.014200429984914883, - page=19)), - Span(start=30263, end=30268, - box=Box(l=0.6601529535879158, t=0.37500875791374183, w=0.03958449391542752, - h=0.014200429984914883, page=19)), - Span(start=30269, end=30273, - box=Box(l=0.7098795933314969, t=0.37500875791374183, w=0.035708386697935225, - h=0.014200429984914883, - page=19)), - Span(start=30274, end=30276, - box=Box(l=0.7519142690113497, t=0.37500875791374183, w=0.013395665875997033, - h=0.014200429984914883, - page=19)), - Span(start=30277, end=30278, - box=Box(l=0.7716362238692644, t=0.37500875791374183, w=0.008917054945941066, - h=0.014200429984914883, - page=19)), - Span(start=30279, end=30281, - box=Box(l=0.7868795677971232, t=0.37500875791374183, w=0.02454198455842322, - h=0.014200429984914883, page=19)), - Span(start=30282, end=30291, - box=Box(l=0.12095229775767885, t=0.3921677691859983, w=0.08031374488472577, - h=0.014200429984914883, page=19)), - Span(start=30292, end=30296, - box=Box(l=0.2062869069137678, t=0.3921677691859983, w=0.03904224057412024, - h=0.014200429984914883, page=19)), - Span(start=30297, end=30302, - box=Box(l=0.25035001175925126, t=0.3921677691859983, w=0.050208642713631964, - h=0.014200429984914883, - page=19)), - Span(start=30303, end=30311, - box=Box(l=0.30557951874424644, t=0.3921677691859983, w=0.08143841848151108, - h=0.014200429984914883, page=19)), - Span(start=30312, end=30314, - box=Box(l=0.3920388014971207, t=0.3921677691859983, w=0.016729519752182193, - h=0.014200429984914883, page=19)), - Span(start=30315, end=30321, - box=Box(l=0.4137891855206661, t=0.3921677691859983, w=0.0535625800469026, - h=0.014200429984914883, page=19)), - Span(start=30322, end=30328, - box=Box(l=0.47237262983893197, t=0.3921677691859983, w=0.05354249658981717, - h=0.014200429984914883, page=19)), - Span(start=30329, end=30333, - box=Box(l=0.5309359907001122, t=0.3921677691859983, w=0.03681297683763493, - h=0.014200429984914883, page=19)), - Span(start=30334, end=30336, - box=Box(l=0.5727698318091105, t=0.3921677691859983, w=0.01672951975218224, - h=0.014200429984914883, page=19)), - Span(start=30337, end=30344, - box=Box(l=0.5945202158326559, t=0.3921677691859983, w=0.060230287799273016, - h=0.014200429984914883, page=19)), - Span(start=30345, end=30348, - box=Box(l=0.6597713679032922, t=0.3921677691859983, w=0.029000512031393946, - h=0.014200429984914883, page=19)), - Span(start=30349, end=30359, - box=Box(l=0.6937927442060494, t=0.3921677691859983, w=0.07834556609035141, - h=0.014200429984914883, page=19))] + Span( + start=3944, + end=3948, + box=Box( + l=0.19238134915568578, + t=0.22752901673615306, + w=0.06941334053447479, + h=0.029442207414270286, + page=4, + ), + ), + Span( + start=3949, + end=3951, + box=Box( + l=0.27220460878651254, + t=0.22752901673615306, + w=0.03468585042904468, + h=0.029442207414270286, + page=4, + ), + ), + Span( + start=4060, + end=4063, + box=Box( + l=0.4204075769894973, + t=0.34144142726484455, + w=0.023417310961637895, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4072, + end=4075, + box=Box( + l=0.5182742633669088, + t=0.34144142726484455, + w=0.029000512031393755, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4076, + end=4083, + box=Box( + l=0.5522956396696659, + t=0.34144142726484455, + w=0.06440764687304719, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4119, + end=4128, + box=Box( + l=0.2686971421659869, + t=0.36273518298114954, + w=0.08479235581478171, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4134, + end=4144, + box=Box( + l=0.40387889180816966, + t=0.36273518298114954, + w=0.08368776567508182, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4145, + end=4148, + box=Box( + l=0.4943548659781345, + t=0.36273518298114954, + w=0.042396177907390975, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4149, + end=4162, + box=Box( + l=0.5435392523804085, + t=0.36273518298114954, + w=0.11491754144296094, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4166, + end=4177, + box=Box( + l=0.6876581404256177, + t=0.36273518298114954, + w=0.09146006356715199, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4419, + end=4427, + box=Box( + l=0.2686971421659869, + t=0.4479113936500019, + w=0.06846450520430858, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4497, + end=4505, + box=Box( + l=0.2686971421659869, + t=0.46920514936630686, + w=0.06846450520430858, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4517, + end=4520, + box=Box( + l=0.42195400318507725, + t=0.46920514936630686, + w=0.029000512031393755, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4574, + end=4581, + box=Box( + l=0.2686971421659869, + t=0.49049890508261185, + w=0.07810456460532592, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4582, + end=4587, + box=Box( + l=0.35061756361754887, + t=0.49049890508261185, + w=0.03904224057412029, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4588, + end=4591, + box=Box( + l=0.39347566103790516, + t=0.49049890508261185, + w=0.023417310961637943, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4592, + end=4601, + box=Box( + l=0.4207088288457791, + t=0.49049890508261185, + w=0.08254300862121101, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4602, + end=4613, + box=Box( + l=0.5070676943132262, + t=0.49049890508261185, + w=0.09481400090042272, + h=0.014200429984914883, + page=4, + ), + ), +] + +list_of_spans_to_merge_2 = [ + Span( + start=30113, + end=30119, + box=Box( + l=0.12095229775767885, + t=0.3578497466414853, + w=0.05243790645011725, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30120, + end=30124, + box=Box( + l=0.17929474059091924, + t=0.3578497466414853, + w=0.030687522426571887, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30125, + end=30129, + box=Box( + l=0.21799556239458678, + t=0.3578497466414853, + w=0.04350076804709073, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30130, + end=30135, + box=Box( + l=0.26740086682480063, + t=0.3578497466414853, + w=0.050208642713631964, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30136, + end=30141, + box=Box( + l=0.32351404592155575, + t=0.3578497466414853, + w=0.0446254416438761, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30142, + end=30151, + box=Box( + l=0.37404402394855496, + t=0.3578497466414853, + w=0.0769598075514552, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30152, + end=30155, + box=Box( + l=0.4569284513402187, + t=0.3578497466414853, + w=0.029000512031393852, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30156, + end=30165, + box=Box( + l=0.4918334997547357, + t=0.3578497466414853, + w=0.0792091547450259, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30166, + end=30175, + box=Box( + l=0.5769471908828846, + t=0.3578497466414853, + w=0.07175819216632291, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30176, + end=30179, + box=Box( + l=0.6576023545380633, + t=0.3578497466414853, + w=0.03122977576787907, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30180, + end=30184, + box=Box( + l=0.6947366666890655, + t=0.3578497466414853, + w=0.03904224057412024, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30185, + end=30190, + box=Box( + l=0.7396834436463088, + t=0.3578497466414853, + w=0.05020864271363187, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30191, + end=30193, + box=Box( + l=0.7957966227430638, + t=0.3578497466414853, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30194, + end=30197, + box=Box( + l=0.12095229775767885, + t=0.37500875791374183, + w=0.024541984558423317, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30198, + end=30207, + box=Box( + l=0.1518205712980198, + t=0.37500875791374183, + w=0.07695980755145514, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30208, + end=30210, + box=Box( + l=0.2351066678313926, + t=0.37500875791374183, + w=0.013395665875996984, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30211, + end=30214, + box=Box( + l=0.2548286226893072, + t=0.37500875791374183, + w=0.02231272082193805, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30215, + end=30217, + box=Box( + l=0.283467632493163, + t=0.37500875791374183, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30218, + end=30221, + box=Box( + l=0.3054188510875629, + t=0.37500875791374183, + w=0.024541984558423317, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30222, + end=30229, + box=Box( + l=0.33628712462790383, + t=0.37500875791374183, + w=0.055570925755447906, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30230, + end=30235, + box=Box( + l=0.3981843393652693, + t=0.37500875791374183, + w=0.04183384110899822, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30236, + end=30240, + box=Box( + l=0.44668588822663785, + t=0.37500875791374183, + w=0.03570838669793504, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30241, + end=30244, + box=Box( + l=0.4887205639064905, + t=0.37500875791374183, + w=0.020083457085452783, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30245, + end=30255, + box=Box( + l=0.5151303099738609, + t=0.37500875791374183, + w=0.08810612623388145, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30256, + end=30259, + box=Box( + l=0.6095627251896601, + t=0.37500875791374183, + w=0.022312720821938, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30260, + end=30262, + box=Box( + l=0.6382017349935157, + t=0.37500875791374183, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30263, + end=30268, + box=Box( + l=0.6601529535879158, + t=0.37500875791374183, + w=0.03958449391542752, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30269, + end=30273, + box=Box( + l=0.7098795933314969, + t=0.37500875791374183, + w=0.035708386697935225, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30274, + end=30276, + box=Box( + l=0.7519142690113497, + t=0.37500875791374183, + w=0.013395665875997033, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30277, + end=30278, + box=Box( + l=0.7716362238692644, + t=0.37500875791374183, + w=0.008917054945941066, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30279, + end=30281, + box=Box( + l=0.7868795677971232, + t=0.37500875791374183, + w=0.02454198455842322, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30282, + end=30291, + box=Box( + l=0.12095229775767885, + t=0.3921677691859983, + w=0.08031374488472577, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30292, + end=30296, + box=Box( + l=0.2062869069137678, + t=0.3921677691859983, + w=0.03904224057412024, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30297, + end=30302, + box=Box( + l=0.25035001175925126, + t=0.3921677691859983, + w=0.050208642713631964, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30303, + end=30311, + box=Box( + l=0.30557951874424644, + t=0.3921677691859983, + w=0.08143841848151108, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30312, + end=30314, + box=Box( + l=0.3920388014971207, + t=0.3921677691859983, + w=0.016729519752182193, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30315, + end=30321, + box=Box( + l=0.4137891855206661, + t=0.3921677691859983, + w=0.0535625800469026, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30322, + end=30328, + box=Box( + l=0.47237262983893197, + t=0.3921677691859983, + w=0.05354249658981717, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30329, + end=30333, + box=Box( + l=0.5309359907001122, + t=0.3921677691859983, + w=0.03681297683763493, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30334, + end=30336, + box=Box( + l=0.5727698318091105, + t=0.3921677691859983, + w=0.01672951975218224, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30337, + end=30344, + box=Box( + l=0.5945202158326559, + t=0.3921677691859983, + w=0.060230287799273016, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30345, + end=30348, + box=Box( + l=0.6597713679032922, + t=0.3921677691859983, + w=0.029000512031393946, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30349, + end=30359, + box=Box( + l=0.6937927442060494, + t=0.3921677691859983, + w=0.07834556609035141, + h=0.014200429984914883, + page=19, + ), + ), +] def test_merge_spans(): - assert len(list_of_spans_to_merge) == (len(MergeSpans(list_of_spans_to_merge, 0, 0) - .merge_neighbor_spans_by_box_coordinate())) + assert len(list_of_spans_to_merge) == ( + len( + MergeSpans( + list_of_spans_to_merge, 0, 0 + ).merge_neighbor_spans_by_box_coordinate() + ) + ) - assert 4 == len(MergeSpans(list_of_spans_to_merge, 0.04387334, 0.01421097).merge_neighbor_spans_by_box_coordinate()) + assert 4 == len( + MergeSpans( + list_of_spans_to_merge, 0.04387334, 0.01421097 + ).merge_neighbor_spans_by_box_coordinate() + ) merge_spans = MergeSpans(list_of_spans_to_merge_2, 0.04387334, 0.01421097) assert 1 == len(merge_spans.merge_neighbor_spans_by_box_coordinate()) - assert [30113, 30359] == [merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end] + assert [30113, 30359] == [ + merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, + merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end, + ] def test_merge_neighbor_spans_by_symbol_distance(): - assert 7 == (len(MergeSpans(list_of_spans_to_merge, index_distance=10) - .merge_neighbor_spans_by_symbol_distance())) - + assert 7 == ( + len( + MergeSpans( + list_of_spans_to_merge, index_distance=10 + ).merge_neighbor_spans_by_symbol_distance() + ) + ) - assert 10 == len(MergeSpans(list_of_spans_to_merge, index_distance=1).merge_neighbor_spans_by_symbol_distance()) + assert 10 == len( + MergeSpans( + list_of_spans_to_merge, index_distance=1 + ).merge_neighbor_spans_by_symbol_distance() + ) list_of_spans_to_merge_2 = [ Span(start=1, end=3, box=Box(l=0.1, t=0.2, w=0.2, h=0.2, page=11)), @@ -370,30 +867,41 @@ def test_from_span_groups_with_box_groups(): list_of_spans_to_merge_in_span_group_format.append( SpanGroup( spans=[Span(start=span.start, end=span.end)], - box_group=BoxGroup(boxes=[span.box]) + box_group=BoxGroup(boxes=[span.box]), ) ) - assert 7 == (len(MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, - index_distance=10).merge_neighbor_spans_by_symbol_distance()) - ) + assert 7 == ( + len( + MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, index_distance=10 + ).merge_neighbor_spans_by_symbol_distance() + ) + ) - assert len(list_of_spans_to_merge) == (len(MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, - 0, - 0).merge_neighbor_spans_by_box_coordinate())) + assert len(list_of_spans_to_merge) == ( + len( + MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, 0, 0 + ).merge_neighbor_spans_by_box_coordinate() + ) + ) def test_box_groups_to_span_groups(): # basic doc annotated with pages and tokens, from pdfplumber parser split at punctuation - with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", "r") as f: + with open( + fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) doc = Document.from_json(fixture_doc_json) # boxes drawn neatly around bib entries - with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r") as f: + with open( + fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r" + ) as f: raw_json = f.read() fixture_bib_entries_json = json.loads(raw_json)["bib_entries"] @@ -404,15 +912,28 @@ def test_box_groups_to_span_groups(): # generate span_groups with different settings overlap_span_groups = box_groups_to_span_groups(box_groups, doc, center=False) - overlap_at_token_center_span_groups = box_groups_to_span_groups(box_groups, doc, center=True) - overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups(box_groups, doc, center=True, pad_x=True) - - assert (len(box_groups) == len(overlap_span_groups) == len(overlap_at_token_center_span_groups) == len(overlap_at_token_center_span_groups_x_padded)) + overlap_at_token_center_span_groups = box_groups_to_span_groups( + box_groups, doc, center=True + ) + overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups( + box_groups, doc, center=True, pad_x=True + ) + + assert ( + len(box_groups) + == len(overlap_span_groups) + == len(overlap_at_token_center_span_groups) + == len(overlap_at_token_center_span_groups_x_padded) + ) # annotate all onto doc to extract texts: doc.annotate(overlap_span_groups=overlap_span_groups) - doc.annotate(overlap_at_token_center_span_groups=overlap_at_token_center_span_groups) - doc.annotate(overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded) + doc.annotate( + overlap_at_token_center_span_groups=overlap_at_token_center_span_groups + ) + doc.annotate( + overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded + ) # when center=False, any token overlap with BoxGroup becomes part of the SpanGroup # in this example, tokens from bib entry '29 and '31' overlap with the box drawn neatly around '30' @@ -444,4 +965,6 @@ def test_box_groups_to_span_groups(): assert doc.overlap_at_token_center_span_groups_x_padded[6].text.startswith("[6]") # original box_group boxes are saved - assert all([sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups]) + assert all( + [sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups] + ) From 524e260fd4ef5fd6fe4ba69cfde1f94d558bf9a9 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 20:32:51 -0700 Subject: [PATCH 18/25] move MergeSpan into figure table heuristic predictor --- .../figure_table_predictors.py | 542 ++++++++++++++---- src/mmda/utils/tools.py | 233 ++------ 2 files changed, 469 insertions(+), 306 deletions(-) diff --git a/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py b/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py index f0c47c85..c90e4585 100644 --- a/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py +++ b/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py @@ -1,31 +1,211 @@ from collections import defaultdict +from itertools import groupby +from typing import Dict, List, Tuple, Union -from typing import List, Dict, Tuple, Union import numpy as np from scipy.optimize import linear_sum_assignment - from tqdm import tqdm from ai2_internal import api -from mmda.predictors.base_predictors.base_heuristic_predictor import BaseHeuristicPredictor -from mmda.types import SpanGroup, BoxGroup +from ai2_internal.api import Relation +from mmda.predictors.base_predictors.base_heuristic_predictor import ( + BaseHeuristicPredictor, +) +from mmda.types import BoxGroup, SpanGroup from mmda.types.document import Document from mmda.types.span import Span -from mmda.utils.tools import MergeSpans -from ai2_internal.api import Relation -class FigureTablePredictions(BaseHeuristicPredictor): - """Class for creating a map of figure boxes to figure captions +class MergeSpans: + """ + Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans + which are index distance apart + Inspired by https://leetcode.com/problems/merge-intervals/ """ - REQUIRED_DOCUMENT_FIELDS = ['pages', 'tokens', 'vila_span_groups', 'blocks', ] + + def __init__( + self, + list_of_spans: List["Span"], + w: float = 0, + h: float = 0, + index_distance: int = 1, + ) -> None: + """ + Args + w (float): The input width between boxes to merge + h (float): The input height between the boxes to merge + index_distance (int): Distance between the spans + """ + self.list_of_spans = list_of_spans + self.w = w + self.h = h + self.graph = defaultdict(list) + self.index_distance = index_distance + + @classmethod + def from_span_groups_with_box_groups( + cls, + span_groups: List["SpanGroup"], + w: float = 0, + h: float = 0, + index_distance: int = 1, + ) -> MergeSpans: + # Convert SpanGroups with single box_group box into SpanGroups with span.box + spans_with_boxes = [] + for sg in span_groups: + assert len(sg.spans) == len( + sg.box_group.boxes + ), "Unequal number of spans and boxes for SpanGroup" + for span, box in zip(sg.spans, sg.box_group.boxes): + spans_with_boxes.append(Span(start=span.start, end=span.end, box=box)) + return cls(spans_with_boxes, w, h, index_distance) + + def build_graph_index_overlap(self): + """ + Build graph, each node is represented by (start, end) of tuple, with the list of spans. Spans are considered + overlapping if they are index_distance apart + """ + starts_matrix = np.full( + (len(self.list_of_spans), len(self.list_of_spans)), + [span.start for span in self.list_of_spans], + ) + ends_matrix = np.full( + (len(self.list_of_spans), len(self.list_of_spans)), + [span.end for span in self.list_of_spans], + ) + + starts_minus_ends = np.abs(starts_matrix - ends_matrix.T) + ends_minus_starts = np.abs(ends_matrix - starts_matrix.T) + are_neighboring_spans = ( + np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance + ) + neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) + + if len(neighboring_spans) > 0: + neighboring_spans_no_dupes = neighboring_spans[ + np.where(neighboring_spans[:, 1] < neighboring_spans[:, 0]) + ] + + for j, i in neighboring_spans_no_dupes: + span_i = self.list_of_spans[i] + span_j = self.list_of_spans[j] + self.graph[span_i.start, span_i.end].append(span_j) + self.graph[span_j.start, span_j.end].append(span_i) + + def build_graph_box_overlap(self): + """ + Build graph, each node is represented by (start, end) of tuple, with the list of spans with overlapping + boxes given, w, h + """ + for i, span_i in enumerate(self.list_of_spans): + assert hasattr(span_i, "box"), "Missing attribute box in a span" + for j in range(i + 1, len(self.list_of_spans)): + assert hasattr( + self.list_of_spans[j], "box" + ), "Missing attribute box in a span" + if span_i.box.is_overlap(self.list_of_spans[j].box, self.w, self.h): + self.graph[span_i.start, span_i.end].append(self.list_of_spans[j]) + self.graph[ + self.list_of_spans[j].start, self.list_of_spans[j].end + ].append(span_i) + + # gets the connected components of the boxes overlap graph. + def get_components(self): + """ + Groups connected graph nodes into dictionary list + """ + visited = set() + comp_number = 0 + nodes_in_comp = defaultdict(list) + + def mark_component_dfs(start): + stack = [start] + while stack: + span = stack.pop() + node = span.start, span.end + if node not in visited: + visited.add(node) + nodes_in_comp[comp_number].append(span) + stack.extend(self.graph[node]) + + # mark all nodes in the same connected component with the same integer. + for span in self.list_of_spans: + center = span.start, span.end + if center not in visited: + mark_component_dfs(span) + comp_number += 1 + + return nodes_in_comp, comp_number + + def merge_neighbor_spans_by_symbol_distance(self): + """ + For each of the lists of the connected nodes determined by index distance between the spans, + merge boxes and find, min, max of the index + """ + return self.build_merged_spans_from_connected_components(index=True) + + def merge_neighbor_spans_by_box_coordinate(self): + """ + For each of the lists of the connected nodes determined by distance between the boxes, + merge boxes and find, min, max of the index + """ + return self.build_merged_spans_from_connected_components(index=False) + + def build_merged_spans_from_connected_components(self, index): + """ + For each of the lists of the connected nodes determined by symbol distance or box distance, + merge boxes and find, min, max of the index + """ + if index: + self.build_graph_index_overlap() + else: + self.build_graph_box_overlap() + + nodes_in_comp, number_of_comps = self.get_components() + + # all intervals in each connected component must be merged. + merged_spans = [] + for comp in range(number_of_comps): + if nodes_in_comp[comp]: + spans_by_page: Dict[any, List[Span]] = defaultdict(list) + for pg, page_spans in groupby( + nodes_in_comp[comp], + lambda s: s.box.page if s.box is not None else None, + ): + for span in page_spans: + spans_by_page[pg].append(span) + for page_spans in spans_by_page.values(): + merged_box = Box.small_boxes_to_big_box( + [span.box for span in page_spans] + ) + merged_spans.append( + Span( + start=min([span.start for span in page_spans]), + end=max([span.end for span in page_spans]), + box=merged_box, + ) + ) + return merged_spans + + +class FigureTablePredictions(BaseHeuristicPredictor): + """Class for creating a map of figure boxes to figure captions""" + + REQUIRED_DOCUMENT_FIELDS = [ + "pages", + "tokens", + "vila_span_groups", + "blocks", + ] def __init__(self, document: Document) -> None: self.doc = document self.vila_caption_dict = None self.vila_spans_all_dict = None self.width_heights_dict = None - self.w_avg, self.h_avg = FigureTablePredictions.get_avg_w_h_of_tokens(self.doc.tokens) + self.w_avg, self.h_avg = FigureTablePredictions.get_avg_w_h_of_tokens( + self.doc.tokens + ) # Parameteer for the fraction of the tokens classified as non-caption that are probably caption in same # Layoutparser span group self.FRACTION_OF_MISCLASSIFIED_VILA_CAPTION_TOKENS = 0.3 @@ -50,12 +230,18 @@ def get_avg_w_h_of_tokens(tokens) -> Tuple[float, float]: """ Get the average width and height of tokens """ - return np.average([[span.box.w, span.box.h] for token in tokens - for span in token.spans], axis=0) + return np.average( + [[span.box.w, span.box.h] for token in tokens for span in token.spans], + axis=0, + ) @staticmethod - def _create_dict_of_pages_spans_layoutparser(layoutparser_span_groups, types: List[str] = [], starts_with: str = '', - negation: bool = False) -> Dict[int, List[SpanGroup]]: + def _create_dict_of_pages_spans_layoutparser( + layoutparser_span_groups, + types: List[str] = [], + starts_with: str = "", + negation: bool = False, + ) -> Dict[int, List[SpanGroup]]: """ Create a dictionary of page number to list of spans, filtering or negating to the types and starts_with """ @@ -63,7 +249,9 @@ def _create_dict_of_pages_spans_layoutparser(layoutparser_span_groups, types: Li for span_group in layoutparser_span_groups: if not types or span_group.box_group.type in types: if negation: - starts_with_bool = not span_group.text.lower().startswith(starts_with) + starts_with_bool = not span_group.text.lower().startswith( + starts_with + ) else: starts_with_bool = span_group.text.lower().startswith(starts_with) @@ -72,13 +260,16 @@ def _create_dict_of_pages_spans_layoutparser(layoutparser_span_groups, types: Li # Creating unique start, end of spans used as a key for merging boxes box_api = api.Box.from_mmda(box) if span_group.spans and len(span_group.spans) == 1: - start, end = span_group.spans[0].start, span_group.spans[0].end + start, end = ( + span_group.spans[0].start, + span_group.spans[0].end, + ) else: start, end = -9999, -9999 - created_span = api.Span(start=start, - end=end, - box=box_api).to_mmda() + created_span = api.Span( + start=start, end=end, box=box_api + ).to_mmda() created_span.span_id = span_group.id created_span.box_group_type = span_group.box_group.type @@ -86,19 +277,19 @@ def _create_dict_of_pages_spans_layoutparser(layoutparser_span_groups, types: Li # Bring in the boxes from the span groups for span in span_group.spans: box_api = api.Box.from_mmda(span.box) - created_span = api.Span(start=span.start, - end=span.end, - box=box_api).to_mmda() + created_span = api.Span( + start=span.start, end=span.end, box=box_api + ).to_mmda() # Note that hash output is changing everytime it is called - created_span.span_id = f'LP_span_group_{span.box.page}_{len(span_map[span.box.page])}' + created_span.span_id = f"LP_span_group_{span.box.page}_{len(span_map[span.box.page])}" created_span.box_group_type = span_group.box_group.type span_map[span.box.page].append(created_span) return span_map @staticmethod def generate_map_of_layout_to_tokens( - vila_dict, layout_parser_overlap, dict_of_pages_layoutparser, - key='caption') -> Dict[int, Dict]: + vila_dict, layout_parser_overlap, dict_of_pages_layoutparser, key="caption" + ) -> Dict[int, Dict]: """ Generate a map of layoutparser entries to the list of vila tokens with the type = key vs type != key """ @@ -106,14 +297,17 @@ def generate_map_of_layout_to_tokens( for span in vila_dict[page]: for layout_span in dict_of_pages_layoutparser.get(page, []): if span.box.is_overlap(layout_span.box): - id_dict = layout_parser_overlap.get(layout_span.span_id, {'caption': [], 'non_caption': []}) + id_dict = layout_parser_overlap.get( + layout_span.span_id, {"caption": [], "non_caption": []} + ) id_dict[key].append(span.span_id) layout_parser_overlap[layout_span.span_id] = id_dict return layout_parser_overlap @staticmethod def generate_map_of_layout_to_tokens_for_page( - vila_list: List, layout_parser_list: List, key='caption') -> Dict[int, Dict]: + vila_list: List, layout_parser_list: List, key="caption" + ) -> Dict[int, Dict]: """ Generate a map of layoutparser tokens ids to the count of vila tokens with the type = key """ @@ -121,16 +315,21 @@ def generate_map_of_layout_to_tokens_for_page( for span in vila_list: for layout_span in layout_parser_list: if span.box.is_overlap(layout_span.box): - id_dict = layout_parser_overlap.get(layout_span.span_id, {'caption': [], 'non_caption': []}) + id_dict = layout_parser_overlap.get( + layout_span.span_id, {"caption": [], "non_caption": []} + ) if span.type.lower() == key: id_dict[key].append(span.span_id) else: - id_dict['non_caption'].append(span.span_id) + id_dict["non_caption"].append(span.span_id) layout_parser_overlap[layout_span.span_id] = id_dict return layout_parser_overlap - def update_vila_caption_dict(self, vila_caption_dict: Dict[int, List[Span]], - vila_non_caption_dict: Dict[int, List[Span]]) -> Dict[int, List[Span]]: + def update_vila_caption_dict( + self, + vila_caption_dict: Dict[int, List[Span]], + vila_non_caption_dict: Dict[int, List[Span]], + ) -> Dict[int, List[Span]]: """ Update the vila caption dict to cast tokens that are misclassified as no captions in ths same LayoutParser region @@ -138,18 +337,25 @@ def update_vila_caption_dict(self, vila_caption_dict: Dict[int, List[Span]], layout_parser_overlap = defaultdict(dict) # Build overlap map between layoutparser and caption tokens span_map = FigureTablePredictions._create_dict_of_pages_spans_layoutparser( - self.doc.blocks) + self.doc.blocks + ) layout_parser_overlap = FigureTablePredictions.generate_map_of_layout_to_tokens( - vila_caption_dict, layout_parser_overlap, span_map) + vila_caption_dict, layout_parser_overlap, span_map + ) # Build overlap map between layoutparser and non-caption tokens layout_parser_overlap = FigureTablePredictions.generate_map_of_layout_to_tokens( - vila_non_caption_dict, layout_parser_overlap, span_map, key='non_caption') + vila_non_caption_dict, layout_parser_overlap, span_map, key="non_caption" + ) for key, value in layout_parser_overlap.items(): - caption_token_fraction = len(value['caption']) / (len(value['caption']) + len(value['non_caption'])) - if ((1.0 > caption_token_fraction) and - (caption_token_fraction > self.FRACTION_OF_MISCLASSIFIED_VILA_CAPTION_TOKENS)): - for span_id in layout_parser_overlap[key]['non_caption']: + caption_token_fraction = len(value["caption"]) / ( + len(value["caption"]) + len(value["non_caption"]) + ) + if (1.0 > caption_token_fraction) and ( + caption_token_fraction + > self.FRACTION_OF_MISCLASSIFIED_VILA_CAPTION_TOKENS + ): + for span_id in layout_parser_overlap[key]["non_caption"]: for page, vila_span in vila_non_caption_dict.items(): for entry in vila_span: if entry.span_id == span_id: @@ -158,14 +364,18 @@ def update_vila_caption_dict(self, vila_caption_dict: Dict[int, List[Span]], return vila_caption_dict @staticmethod - def _filter_span_group(vila_span_groups: List[api.SpanGroup], caption_content: str, span_group_types: List[str], - negation=False) -> List[api.SpanGroup]: + def _filter_span_group( + vila_span_groups: List[api.SpanGroup], + caption_content: str, + span_group_types: List[str], + negation=False, + ) -> List[api.SpanGroup]: """ Helper function which filters out span groups based on the caption content and span group type """ result = [] for span_group in vila_span_groups: - if span_group.text.replace(' ', '').lower().startswith(caption_content): + if span_group.text.replace(" ", "").lower().startswith(caption_content): if span_group.type in span_group_types and not negation: result.append(span_group) elif negation and span_group.type not in span_group_types: @@ -173,7 +383,8 @@ def _filter_span_group(vila_span_groups: List[api.SpanGroup], caption_content: s return result def merge_vila_token_spans( - self, caption_content: str = 'fig', span_group_type: List[str] = ['Caption']) -> Dict[int, List[api.Box]]: + self, caption_content: str = "fig", span_group_type: List[str] = ["Caption"] + ) -> Dict[int, List[api.Box]]: """ Merging spanGroups Args: @@ -183,7 +394,10 @@ def merge_vila_token_spans( Returns: Dictionary page -> List of merged boxes """ vila_span_groups_filtered = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, caption_content=caption_content, span_group_types=span_group_type) + self.doc.vila_span_groups, + caption_content=caption_content, + span_group_types=span_group_type, + ) vila_caption_dict = defaultdict(list) for entry_caption in vila_span_groups_filtered: @@ -194,11 +408,12 @@ def merge_vila_token_spans( merged_boxes_list = defaultdict(list) for page, list_of_boxes in vila_caption_dict.items(): # Merge spans if they are sufficiently close to each other - merged_boxes_list[page] = MergeSpans(list_of_spans=list_of_boxes, w=self.w_avg * 1.5, - h=self.h_avg * 1).merge_neighbor_spans_by_box_coordinate() + merged_boxes_list[page] = MergeSpans( + list_of_spans=list_of_boxes, w=self.w_avg * 1.5, h=self.h_avg * 1 + ).merge_neighbor_spans_by_box_coordinate() return merged_boxes_list - def _cast_to_caption_vila_tokens(self, caption_content='fig'): + def _cast_to_caption_vila_tokens(self, caption_content="fig"): """ Heuristic logic for fixing miss classified tokens as non caption. By checking layoutparser box predictions and tokens which belong to them, I cast the rest of the tokens to caption category. @@ -210,19 +425,35 @@ def _cast_to_caption_vila_tokens(self, caption_content='fig'): # First let's go over all the tokens which are labeled as caption and find the LayoutParser SpanGroups which # they overlap with vila_caption = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, caption_content=caption_content, span_group_types=['Caption']) + self.doc.vila_span_groups, + caption_content=caption_content, + span_group_types=["Caption"], + ) - self.vila_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila(vila_caption) + self.vila_caption_dict = ( + FigureTablePredictions._create_dict_of_pages_spans_vila(vila_caption) + ) vila_non_caption = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, caption_content='', span_group_types=['Caption'], negation=True) - - vila_non_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila(vila_non_caption) - return self.update_vila_caption_dict(self.vila_caption_dict, vila_non_caption_dict) - - def merge_boxes(self, layoutparser_span_groups: List[api.SpanGroup], - merged_boxes_vila_dict: Dict[int, List[api.Box]] = None, - types: List[str] = ['Figure']) -> Dict[int, List[api.Box]]: + self.doc.vila_span_groups, + caption_content="", + span_group_types=["Caption"], + negation=True, + ) + + vila_non_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila( + vila_non_caption + ) + return self.update_vila_caption_dict( + self.vila_caption_dict, vila_non_caption_dict + ) + + def merge_boxes( + self, + layoutparser_span_groups: List[api.SpanGroup], + merged_boxes_vila_dict: Dict[int, List[api.Box]] = None, + types: List[str] = ["Figure"], + ) -> Dict[int, List[api.Box]]: """ Merges overlapping boxes. Vila caption predictions is more consistent than layout parser prediction, thus we check the number of items after the merge with the number of caption boxes. @@ -241,22 +472,40 @@ def merge_boxes(self, layoutparser_span_groups: List[api.SpanGroup], merged_boxes_vila_dict_left = defaultdict(list) merged_boxes_map = defaultdict(list) span_map = FigureTablePredictions._create_dict_of_pages_spans_layoutparser( - layoutparser_span_groups, types=types) + layoutparser_span_groups, types=types + ) for page, span_list in span_map.items(): # Adding vila spans to the layout parser list of the spans if merged_boxes_vila_dict[page]: span_list.extend(merged_boxes_vila_dict_left[page]) - merged_spans = (MergeSpans(span_list, w=self.w_avg * 0.5, h=self.h_avg * 1.0) - .merge_neighbor_spans_by_box_coordinate()) + merged_spans = MergeSpans( + span_list, w=self.w_avg * 0.5, h=self.h_avg * 1.0 + ).merge_neighbor_spans_by_box_coordinate() # Filtering out vila spans (not merged) - if len(span_list) != len(merged_spans) and merged_boxes_vila_dict and merged_boxes_vila_dict[page]: - merged_spans = [merged_span for merged_span in merged_spans if not any( - vila_span.box.to_json() == merged_span.box.to_json() for vila_span in merged_boxes_vila_dict[page])] - - merged_boxes_vila_dict_left[page] = [vila_span for vila_span in merged_boxes_vila_dict[page] if any( - vila_span.box.to_json() == merged_span.box.to_json() for merged_span in merged_spans)] + if ( + len(span_list) != len(merged_spans) + and merged_boxes_vila_dict + and merged_boxes_vila_dict[page] + ): + merged_spans = [ + merged_span + for merged_span in merged_spans + if not any( + vila_span.box.to_json() == merged_span.box.to_json() + for vila_span in merged_boxes_vila_dict[page] + ) + ] + + merged_boxes_vila_dict_left[page] = [ + vila_span + for vila_span in merged_boxes_vila_dict[page] + if any( + vila_span.box.to_json() == merged_span.box.to_json() + for merged_span in merged_spans + ) + ] if merged_boxes_vila_dict_left[page]: merged_boxes_vila_dict[page] = merged_boxes_vila_dict_left[page] @@ -265,7 +514,9 @@ def merge_boxes(self, layoutparser_span_groups: List[api.SpanGroup], return merged_boxes_map, merged_boxes_vila_dict @staticmethod - def _get_object_caption_distance(figure_box: api.Box, caption_box: api.Box) -> float: + def _get_object_caption_distance( + figure_box: api.Box, caption_box: api.Box + ) -> float: """ Return 900.0 if left point of figure, caption is offset more than 10% Otherwise returns distance middle of the figure box and caption box @@ -282,24 +533,24 @@ def _get_object_caption_distance(figure_box: api.Box, caption_box: api.Box) -> f return t_cap - t_fig - def get_layout_span_groups_starts_with(self, caption_content: str = 'fig', vila_spans: dict = None): - """ - - """ + def get_layout_span_groups_starts_with( + self, caption_content: str = "fig", vila_spans: dict = None + ): + """ """ spans_to_merge_dict = defaultdict(list) - self.vila_caption_dict = self._cast_to_caption_vila_tokens(caption_content=caption_content) + self.vila_caption_dict = self._cast_to_caption_vila_tokens( + caption_content=caption_content + ) if vila_spans: for page_idx, vila_span in vila_spans.items(): - spans_to_merge_dict[page_idx].extend( - vila_span) + spans_to_merge_dict[page_idx].extend(vila_span) layout_parser_span_groups_dict = defaultdict(list) if vila_spans: for page_idx, vila_span in vila_spans.items(): - layout_parser_span_groups_dict[page_idx].extend( - vila_span) + layout_parser_span_groups_dict[page_idx].extend(vila_span) return layout_parser_span_groups_dict def generate_candidates(self) -> Tuple[Union[SpanGroup, BoxGroup]]: @@ -309,41 +560,66 @@ def generate_candidates(self) -> Tuple[Union[SpanGroup, BoxGroup]]: assert self.doc.vila_span_groups merged_boxes_caption_fig_tab_dict = {} - for caption_content in ['fig', 'tab']: + for caption_content in ["fig", "tab"]: # Merge vila tokens which start with caption_content - merged_boxes_caption_fig_tab_dict[caption_content] = self.merge_vila_token_spans( - caption_content=caption_content) + merged_boxes_caption_fig_tab_dict[ + caption_content + ] = self.merge_vila_token_spans(caption_content=caption_content) - merged_boxes_caption_fig_tab_dict[caption_content] = self.get_layout_span_groups_starts_with( - caption_content=caption_content, vila_spans=merged_boxes_caption_fig_tab_dict[caption_content]) + merged_boxes_caption_fig_tab_dict[ + caption_content + ] = self.get_layout_span_groups_starts_with( + caption_content=caption_content, + vila_spans=merged_boxes_caption_fig_tab_dict[caption_content], + ) # Final check that the defined captions are starting with tab and fig - for page_idx, list_of_spans in merged_boxes_caption_fig_tab_dict[caption_content].items(): + for page_idx, list_of_spans in merged_boxes_caption_fig_tab_dict[ + caption_content + ].items(): for span in list_of_spans: - if not self.doc.symbols[span.start:span.end].lower().startswith(caption_content): + if ( + not self.doc.symbols[span.start : span.end] + .lower() + .startswith(caption_content) + ): list_of_spans.remove(span) - merged_boxes_caption_fig_tab_dict[caption_content][page_idx] = list_of_spans + merged_boxes_caption_fig_tab_dict[caption_content][ + page_idx + ] = list_of_spans # merged_boxes_vila_dict is used in figure, table boxes derivation merged_boxes_vila_dict = self.merge_vila_token_spans( - caption_content='', span_group_type=['Text', 'Paragraph', 'Table', 'Figure']) + caption_content="", span_group_type=["Text", "Paragraph", "Table", "Figure"] + ) # Create dictionary of layoutparser span groups merging boxgroups and boxes merged_boxes_vila_dict_left = None merged_boxes_fig_tab_dict = {} # List of types to be merged from layoutparser, note that sometimes figures are marked as Equations - for layout_parser_box_type in ([['Figure'], ['Table']]): - merged_boxes_vila_dict = (merged_boxes_vila_dict_left - if merged_boxes_vila_dict_left is not None else merged_boxes_vila_dict) - merged_boxes_fig_tab_dict[layout_parser_box_type[0]], merged_boxes_vila_dict_left = self.merge_boxes( + for layout_parser_box_type in [["Figure"], ["Table"]]: + merged_boxes_vila_dict = ( + merged_boxes_vila_dict_left + if merged_boxes_vila_dict_left is not None + else merged_boxes_vila_dict + ) + ( + merged_boxes_fig_tab_dict[layout_parser_box_type[0]], + merged_boxes_vila_dict_left, + ) = self.merge_boxes( layoutparser_span_groups=self.doc.blocks, types=layout_parser_box_type, - merged_boxes_vila_dict=merged_boxes_vila_dict) + merged_boxes_vila_dict=merged_boxes_vila_dict, + ) - return (merged_boxes_caption_fig_tab_dict['fig'], merged_boxes_fig_tab_dict['Figure'], - merged_boxes_caption_fig_tab_dict['tab'], merged_boxes_fig_tab_dict['Table']) + return ( + merged_boxes_caption_fig_tab_dict["fig"], + merged_boxes_fig_tab_dict["Figure"], + merged_boxes_caption_fig_tab_dict["tab"], + merged_boxes_fig_tab_dict["Table"], + ) def _predict( - self, merged_boxes_caption_dict, merged_boxes_fig_tab_dict, caption_type + self, merged_boxes_caption_dict, merged_boxes_fig_tab_dict, caption_type ) -> Dict[str, Union[SpanGroup, BoxGroup, Relation]]: """ Merges boxes corresponding to tokens of table, figure captions. For each page each caption/object create cost @@ -360,38 +636,62 @@ def _predict( predictions_captions = [] predictions_relations = [] for page in range(len(tqdm(self.doc.pages))): - if merged_boxes_caption_dict.get(page) and merged_boxes_fig_tab_dict.get(page): + if merged_boxes_caption_dict.get(page) and merged_boxes_fig_tab_dict.get( + page + ): cost_matrix = np.zeros( - (len(merged_boxes_fig_tab_dict[page]), - len(merged_boxes_caption_dict[page]))) + ( + len(merged_boxes_fig_tab_dict[page]), + len(merged_boxes_caption_dict[page]), + ) + ) for j, fig_box in enumerate(merged_boxes_fig_tab_dict[page]): for i, span_group in enumerate(merged_boxes_caption_dict[page]): caption_box = span_group.box - assert hasattr(fig_box, 'box') + assert hasattr(fig_box, "box") distance = FigureTablePredictions._get_object_caption_distance( - fig_box.box, caption_box) + fig_box.box, caption_box + ) cost_matrix[j][i] = distance - if caption_type == 'Figure': + if caption_type == "Figure": cost_matrix[j][i] = distance if distance > 0 else 900 row_ind, col_ind = linear_sum_assignment(cost_matrix) for row, col in zip(row_ind, col_ind): # Check that caption starts with tab or fig - if self.doc.symbols[ - merged_boxes_caption_dict[page][col].start: - merged_boxes_caption_dict[page][col].end].lower().startswith(caption_type.lower()[:3]): - span_group = SpanGroup(spans=[Span( - start=merged_boxes_caption_dict[page][col].start, - end=merged_boxes_caption_dict[page][col].end)], - id=len(predictions_captions) + if ( + self.doc.symbols[ + merged_boxes_caption_dict[page][col] + .start : merged_boxes_caption_dict[page][col] + .end + ] + .lower() + .startswith(caption_type.lower()[:3]) + ): + span_group = SpanGroup( + spans=[ + Span( + start=merged_boxes_caption_dict[page][col].start, + end=merged_boxes_caption_dict[page][col].end, + ) + ], + id=len(predictions_captions), + ) + box_group = BoxGroup( + boxes=[merged_boxes_fig_tab_dict[page][row].box], + id=len(predictions), ) - box_group = BoxGroup(boxes=[merged_boxes_fig_tab_dict[page][row].box], id=len(predictions)) predictions.append(box_group) predictions_captions.append(span_group) - predictions_relations.append(Relation(from_id=box_group.id, to_id=span_group.id)) - return {f'{caption_type.lower()}s': predictions, f'{caption_type.lower()}_captions': predictions_captions, - f'{caption_type.lower()}_to_{caption_type.lower()}_captions': predictions_relations} + predictions_relations.append( + Relation(from_id=box_group.id, to_id=span_group.id) + ) + return { + f"{caption_type.lower()}s": predictions, + f"{caption_type.lower()}_captions": predictions_captions, + f"{caption_type.lower()}_to_{caption_type.lower()}_captions": predictions_relations, + } def predict(self) -> Dict[str, Union[SpanGroup, BoxGroup, Relation]]: """ @@ -400,9 +700,25 @@ def predict(self) -> Dict[str, Union[SpanGroup, BoxGroup, Relation]]: information about the boundaries of figure or table. Relation stores information about the relation between caption and the object it corresponds to """ - (merged_boxes_caption_fig_dict, - merged_boxes_fig_dict, merged_boxes_caption_tab_dict, merged_boxes_tab_dict) = self.generate_candidates() + ( + merged_boxes_caption_fig_dict, + merged_boxes_fig_dict, + merged_boxes_caption_tab_dict, + merged_boxes_tab_dict, + ) = self.generate_candidates() result_dict = {} - result_dict.update(self._predict(merged_boxes_caption_fig_dict, merged_boxes_fig_dict, caption_type='Figure')) - result_dict.update(self._predict(merged_boxes_caption_tab_dict, merged_boxes_tab_dict, caption_type='Table')) + result_dict.update( + self._predict( + merged_boxes_caption_fig_dict, + merged_boxes_fig_dict, + caption_type="Figure", + ) + ) + result_dict.update( + self._predict( + merged_boxes_caption_tab_dict, + merged_boxes_tab_dict, + caption_type="Table", + ) + ) return result_dict diff --git a/src/mmda/utils/tools.py b/src/mmda/utils/tools.py index 9effac49..cbda562c 100644 --- a/src/mmda/utils/tools.py +++ b/src/mmda/utils/tools.py @@ -1,20 +1,21 @@ from __future__ import annotations -import logging -from collections import defaultdict -from itertools import groupby import itertools -from typing import List, Dict, Tuple +import logging +from typing import Dict, List, Tuple import numpy as np -from mmda.types.annotation import BoxGroup, SpanGroup -from mmda.types.box import Box -from mmda.types.span import Span +from mmda.types import BoxGroup, Document, Span, SpanGroup def allocate_overlapping_tokens_for_box( - tokens: List[SpanGroup], box, token_box_in_box_group: bool = False, x: float = 0.0, y: float = 0.0, center: bool = False + tokens: List[SpanGroup], + box, + token_box_in_box_group: bool = False, + x: float = 0.0, + y: float = 0.0, + center: bool = False, ) -> Tuple[List[Span], List[Span]]: """Finds overlap of tokens for given box Args @@ -29,10 +30,14 @@ def allocate_overlapping_tokens_for_box( """ allocated_tokens, remaining_tokens = [], [] for token in tokens: - if token_box_in_box_group and token.box_group.boxes[0].is_overlap(other=box, x=x, y=y, center=center): + if token_box_in_box_group and token.box_group.boxes[0].is_overlap( + other=box, x=x, y=y, center=center + ): # The token "box" is stored within the SpanGroup's .box_group allocated_tokens.append(token) - elif token.spans[0].box is not None and token.spans[0].box.is_overlap(other=box, x=x, y=y, center=center): + elif token.spans[0].box is not None and token.spans[0].box.is_overlap( + other=box, x=x, y=y, center=center + ): # default to assuming the token "box" is stored in the SpanGroup .box allocated_tokens.append(token) else: @@ -41,7 +46,7 @@ def allocate_overlapping_tokens_for_box( def box_groups_to_span_groups( - box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False + box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False ) -> List[SpanGroup]: """Generate SpanGroups from BoxGroups. Args @@ -60,23 +65,22 @@ def box_groups_to_span_groups( token_box_in_box_group = None for box_id, box_group in enumerate(box_groups): - all_tokens_overlapping_box_group = [] for box in box_group.boxes: - # Caching the page tokens to avoid duplicated search if box.page not in all_page_tokens: - cur_page_tokens = all_page_tokens[box.page] = doc.pages[ - box.page - ].tokens + cur_page_tokens = all_page_tokens[box.page] = doc.pages[box.page].tokens if token_box_in_box_group is None: # Determine whether box is stored on token SpanGroup span.box or in the box_group token_box_in_box_group = all( [ ( - (hasattr(token.box_group, "boxes") and len(token.box_group.boxes) == 1) - and token.spans[0].box is None + ( + hasattr(token.box_group, "boxes") + and len(token.box_group.boxes) == 1 + ) + and token.spans[0].box is None ) for token in cur_page_tokens ] @@ -84,9 +88,16 @@ def box_groups_to_span_groups( # Determine average width of tokens on this page if we are going to pad x if pad_x: if token_box_in_box_group and box.page not in avg_token_widths: - avg_token_widths[box.page] = np.average([t.box_group.boxes[0].w for t in cur_page_tokens]) - elif not token_box_in_box_group and box.page not in avg_token_widths: - avg_token_widths[box.page] = np.average([t.spans[0].box.w for t in cur_page_tokens]) + avg_token_widths[box.page] = np.average( + [t.box_group.boxes[0].w for t in cur_page_tokens] + ) + elif ( + not token_box_in_box_group + and box.page not in avg_token_widths + ): + avg_token_widths[box.page] = np.average( + [t.spans[0].box.w for t in cur_page_tokens] + ) else: cur_page_tokens = all_page_tokens[box.page] @@ -99,7 +110,7 @@ def box_groups_to_span_groups( # optionally pad x a small amount so that extra narrow token boxes (when split at punctuation) are not missed x=avg_token_widths.get(box.page, 0.0) * 0.5 if pad_x else 0.0, y=0.0, - center=center + center=center, ) all_page_tokens[box.page] = remaining_tokens @@ -113,7 +124,8 @@ def box_groups_to_span_groups( else MergeSpans( list_of_spans=list( itertools.chain.from_iterable( - span_group.spans for span_group in all_tokens_overlapping_box_group + span_group.spans + for span_group in all_tokens_overlapping_box_group ) ), index_distance=1, @@ -131,9 +143,11 @@ def box_groups_to_span_groups( ) if not token_box_in_box_group: - logging.warning("tokens with box stored in SpanGroup span.box will be deprecated (that is, " - "future Spans wont contain box). Ensure Document is annotated with tokens " - "having box stored in SpanGroup box_group.boxes") + logging.warning( + "tokens with box stored in SpanGroup span.box will be deprecated (that is, " + "future Spans wont contain box). Ensure Document is annotated with tokens " + "having box stored in SpanGroup box_group.boxes" + ) del all_page_tokens @@ -149,170 +163,3 @@ def box_groups_to_span_groups( # span_groups=derived_span_groups, field_name=field_name # ) return derived_span_groups - -class MergeSpans: - """ - Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans - which are index distance apart - Inspired by https://leetcode.com/problems/merge-intervals/ - """ - - def __init__( - self, - list_of_spans: List["Span"], - w: float = 0, - h: float = 0, - index_distance: int = 1, - ) -> None: - """ - Args - w (float): The input width between boxes to merge - h (float): The input height between the boxes to merge - index_distance (int): Distance between the spans - """ - self.list_of_spans = list_of_spans - self.w = w - self.h = h - self.graph = defaultdict(list) - self.index_distance = index_distance - - @classmethod - def from_span_groups_with_box_groups( - cls, - span_groups: List["SpanGroup"], - w: float = 0, - h: float = 0, - index_distance: int = 1, - ) -> MergeSpans: - # Convert SpanGroups with single box_group box into SpanGroups with span.box - spans_with_boxes = [] - for sg in span_groups: - assert len(sg.spans) == len( - sg.box_group.boxes - ), "Unequal number of spans and boxes for SpanGroup" - for span, box in zip(sg.spans, sg.box_group.boxes): - spans_with_boxes.append(Span(start=span.start, end=span.end, box=box)) - return cls(spans_with_boxes, w, h, index_distance) - - def build_graph_index_overlap(self): - """ - Build graph, each node is represented by (start, end) of tuple, with the list of spans. Spans are considered - overlapping if they are index_distance apart - """ - starts_matrix = np.full( - (len(self.list_of_spans), len(self.list_of_spans)), - [span.start for span in self.list_of_spans] - ) - ends_matrix = np.full( - (len(self.list_of_spans), len(self.list_of_spans)), - [span.end for span in self.list_of_spans] - ) - - starts_minus_ends = np.abs(starts_matrix - ends_matrix.T) - ends_minus_starts = np.abs(ends_matrix - starts_matrix.T) - are_neighboring_spans = np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance - neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) - - if len(neighboring_spans) > 0: - neighboring_spans_no_dupes = neighboring_spans[np.where(neighboring_spans[:,1] < neighboring_spans[:,0])] - - for j, i in neighboring_spans_no_dupes: - span_i = self.list_of_spans[i] - span_j = self.list_of_spans[j] - self.graph[span_i.start, span_i.end].append(span_j) - self.graph[span_j.start, span_j.end].append(span_i) - - def build_graph_box_overlap(self): - """ - Build graph, each node is represented by (start, end) of tuple, with the list of spans with overlapping - boxes given, w, h - """ - for i, span_i in enumerate(self.list_of_spans): - assert hasattr(span_i, "box"), "Missing attribute box in a span" - for j in range(i + 1, len(self.list_of_spans)): - assert hasattr( - self.list_of_spans[j], "box" - ), "Missing attribute box in a span" - if span_i.box.is_overlap(self.list_of_spans[j].box, self.w, self.h): - self.graph[span_i.start, span_i.end].append(self.list_of_spans[j]) - self.graph[ - self.list_of_spans[j].start, self.list_of_spans[j].end - ].append(span_i) - - # gets the connected components of the boxes overlap graph. - def get_components(self): - """ - Groups connected graph nodes into dictionary list - """ - visited = set() - comp_number = 0 - nodes_in_comp = defaultdict(list) - - def mark_component_dfs(start): - stack = [start] - while stack: - span = stack.pop() - node = span.start, span.end - if node not in visited: - visited.add(node) - nodes_in_comp[comp_number].append(span) - stack.extend(self.graph[node]) - - # mark all nodes in the same connected component with the same integer. - for span in self.list_of_spans: - center = span.start, span.end - if center not in visited: - mark_component_dfs(span) - comp_number += 1 - - return nodes_in_comp, comp_number - - def merge_neighbor_spans_by_symbol_distance(self): - """ - For each of the lists of the connected nodes determined by index distance between the spans, - merge boxes and find, min, max of the index - """ - return self.build_merged_spans_from_connected_components(index=True) - - def merge_neighbor_spans_by_box_coordinate(self): - """ - For each of the lists of the connected nodes determined by distance between the boxes, - merge boxes and find, min, max of the index - """ - return self.build_merged_spans_from_connected_components(index=False) - - def build_merged_spans_from_connected_components(self, index): - """ - For each of the lists of the connected nodes determined by symbol distance or box distance, - merge boxes and find, min, max of the index - """ - if index: - self.build_graph_index_overlap() - else: - self.build_graph_box_overlap() - - nodes_in_comp, number_of_comps = self.get_components() - - # all intervals in each connected component must be merged. - merged_spans = [] - for comp in range(number_of_comps): - if nodes_in_comp[comp]: - spans_by_page: Dict[any, List[Span]] = defaultdict(list) - for pg, page_spans in groupby( - nodes_in_comp[comp], - lambda s: s.box.page if s.box is not None else None, - ): - for span in page_spans: - spans_by_page[pg].append(span) - for page_spans in spans_by_page.values(): - merged_box = Box.small_boxes_to_big_box( - [span.box for span in page_spans] - ) - merged_spans.append( - Span( - start=min([span.start for span in page_spans]), - end=max([span.end for span in page_spans]), - box=merged_box, - ) - ) - return merged_spans From d26a3790870f409b57a41a4a860f8a770c1793e9 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Wed, 5 Jul 2023 22:46:26 -0700 Subject: [PATCH 19/25] revert commit --- .../figure_table_predictors.py | 542 ++-------- src/mmda/types/document.py | 15 +- src/mmda/utils/tools.py | 233 ++++- .../test_figure_table_predictors.py | 317 ++---- tests/test_utils/test_tools.py | 987 ++++-------------- 5 files changed, 656 insertions(+), 1438 deletions(-) diff --git a/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py b/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py index c90e4585..f0c47c85 100644 --- a/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py +++ b/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py @@ -1,211 +1,31 @@ from collections import defaultdict -from itertools import groupby -from typing import Dict, List, Tuple, Union +from typing import List, Dict, Tuple, Union import numpy as np from scipy.optimize import linear_sum_assignment + from tqdm import tqdm from ai2_internal import api -from ai2_internal.api import Relation -from mmda.predictors.base_predictors.base_heuristic_predictor import ( - BaseHeuristicPredictor, -) -from mmda.types import BoxGroup, SpanGroup +from mmda.predictors.base_predictors.base_heuristic_predictor import BaseHeuristicPredictor +from mmda.types import SpanGroup, BoxGroup from mmda.types.document import Document from mmda.types.span import Span - - -class MergeSpans: - """ - Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans - which are index distance apart - Inspired by https://leetcode.com/problems/merge-intervals/ - """ - - def __init__( - self, - list_of_spans: List["Span"], - w: float = 0, - h: float = 0, - index_distance: int = 1, - ) -> None: - """ - Args - w (float): The input width between boxes to merge - h (float): The input height between the boxes to merge - index_distance (int): Distance between the spans - """ - self.list_of_spans = list_of_spans - self.w = w - self.h = h - self.graph = defaultdict(list) - self.index_distance = index_distance - - @classmethod - def from_span_groups_with_box_groups( - cls, - span_groups: List["SpanGroup"], - w: float = 0, - h: float = 0, - index_distance: int = 1, - ) -> MergeSpans: - # Convert SpanGroups with single box_group box into SpanGroups with span.box - spans_with_boxes = [] - for sg in span_groups: - assert len(sg.spans) == len( - sg.box_group.boxes - ), "Unequal number of spans and boxes for SpanGroup" - for span, box in zip(sg.spans, sg.box_group.boxes): - spans_with_boxes.append(Span(start=span.start, end=span.end, box=box)) - return cls(spans_with_boxes, w, h, index_distance) - - def build_graph_index_overlap(self): - """ - Build graph, each node is represented by (start, end) of tuple, with the list of spans. Spans are considered - overlapping if they are index_distance apart - """ - starts_matrix = np.full( - (len(self.list_of_spans), len(self.list_of_spans)), - [span.start for span in self.list_of_spans], - ) - ends_matrix = np.full( - (len(self.list_of_spans), len(self.list_of_spans)), - [span.end for span in self.list_of_spans], - ) - - starts_minus_ends = np.abs(starts_matrix - ends_matrix.T) - ends_minus_starts = np.abs(ends_matrix - starts_matrix.T) - are_neighboring_spans = ( - np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance - ) - neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) - - if len(neighboring_spans) > 0: - neighboring_spans_no_dupes = neighboring_spans[ - np.where(neighboring_spans[:, 1] < neighboring_spans[:, 0]) - ] - - for j, i in neighboring_spans_no_dupes: - span_i = self.list_of_spans[i] - span_j = self.list_of_spans[j] - self.graph[span_i.start, span_i.end].append(span_j) - self.graph[span_j.start, span_j.end].append(span_i) - - def build_graph_box_overlap(self): - """ - Build graph, each node is represented by (start, end) of tuple, with the list of spans with overlapping - boxes given, w, h - """ - for i, span_i in enumerate(self.list_of_spans): - assert hasattr(span_i, "box"), "Missing attribute box in a span" - for j in range(i + 1, len(self.list_of_spans)): - assert hasattr( - self.list_of_spans[j], "box" - ), "Missing attribute box in a span" - if span_i.box.is_overlap(self.list_of_spans[j].box, self.w, self.h): - self.graph[span_i.start, span_i.end].append(self.list_of_spans[j]) - self.graph[ - self.list_of_spans[j].start, self.list_of_spans[j].end - ].append(span_i) - - # gets the connected components of the boxes overlap graph. - def get_components(self): - """ - Groups connected graph nodes into dictionary list - """ - visited = set() - comp_number = 0 - nodes_in_comp = defaultdict(list) - - def mark_component_dfs(start): - stack = [start] - while stack: - span = stack.pop() - node = span.start, span.end - if node not in visited: - visited.add(node) - nodes_in_comp[comp_number].append(span) - stack.extend(self.graph[node]) - - # mark all nodes in the same connected component with the same integer. - for span in self.list_of_spans: - center = span.start, span.end - if center not in visited: - mark_component_dfs(span) - comp_number += 1 - - return nodes_in_comp, comp_number - - def merge_neighbor_spans_by_symbol_distance(self): - """ - For each of the lists of the connected nodes determined by index distance between the spans, - merge boxes and find, min, max of the index - """ - return self.build_merged_spans_from_connected_components(index=True) - - def merge_neighbor_spans_by_box_coordinate(self): - """ - For each of the lists of the connected nodes determined by distance between the boxes, - merge boxes and find, min, max of the index - """ - return self.build_merged_spans_from_connected_components(index=False) - - def build_merged_spans_from_connected_components(self, index): - """ - For each of the lists of the connected nodes determined by symbol distance or box distance, - merge boxes and find, min, max of the index - """ - if index: - self.build_graph_index_overlap() - else: - self.build_graph_box_overlap() - - nodes_in_comp, number_of_comps = self.get_components() - - # all intervals in each connected component must be merged. - merged_spans = [] - for comp in range(number_of_comps): - if nodes_in_comp[comp]: - spans_by_page: Dict[any, List[Span]] = defaultdict(list) - for pg, page_spans in groupby( - nodes_in_comp[comp], - lambda s: s.box.page if s.box is not None else None, - ): - for span in page_spans: - spans_by_page[pg].append(span) - for page_spans in spans_by_page.values(): - merged_box = Box.small_boxes_to_big_box( - [span.box for span in page_spans] - ) - merged_spans.append( - Span( - start=min([span.start for span in page_spans]), - end=max([span.end for span in page_spans]), - box=merged_box, - ) - ) - return merged_spans +from mmda.utils.tools import MergeSpans +from ai2_internal.api import Relation class FigureTablePredictions(BaseHeuristicPredictor): - """Class for creating a map of figure boxes to figure captions""" - - REQUIRED_DOCUMENT_FIELDS = [ - "pages", - "tokens", - "vila_span_groups", - "blocks", - ] + """Class for creating a map of figure boxes to figure captions + """ + REQUIRED_DOCUMENT_FIELDS = ['pages', 'tokens', 'vila_span_groups', 'blocks', ] def __init__(self, document: Document) -> None: self.doc = document self.vila_caption_dict = None self.vila_spans_all_dict = None self.width_heights_dict = None - self.w_avg, self.h_avg = FigureTablePredictions.get_avg_w_h_of_tokens( - self.doc.tokens - ) + self.w_avg, self.h_avg = FigureTablePredictions.get_avg_w_h_of_tokens(self.doc.tokens) # Parameteer for the fraction of the tokens classified as non-caption that are probably caption in same # Layoutparser span group self.FRACTION_OF_MISCLASSIFIED_VILA_CAPTION_TOKENS = 0.3 @@ -230,18 +50,12 @@ def get_avg_w_h_of_tokens(tokens) -> Tuple[float, float]: """ Get the average width and height of tokens """ - return np.average( - [[span.box.w, span.box.h] for token in tokens for span in token.spans], - axis=0, - ) + return np.average([[span.box.w, span.box.h] for token in tokens + for span in token.spans], axis=0) @staticmethod - def _create_dict_of_pages_spans_layoutparser( - layoutparser_span_groups, - types: List[str] = [], - starts_with: str = "", - negation: bool = False, - ) -> Dict[int, List[SpanGroup]]: + def _create_dict_of_pages_spans_layoutparser(layoutparser_span_groups, types: List[str] = [], starts_with: str = '', + negation: bool = False) -> Dict[int, List[SpanGroup]]: """ Create a dictionary of page number to list of spans, filtering or negating to the types and starts_with """ @@ -249,9 +63,7 @@ def _create_dict_of_pages_spans_layoutparser( for span_group in layoutparser_span_groups: if not types or span_group.box_group.type in types: if negation: - starts_with_bool = not span_group.text.lower().startswith( - starts_with - ) + starts_with_bool = not span_group.text.lower().startswith(starts_with) else: starts_with_bool = span_group.text.lower().startswith(starts_with) @@ -260,16 +72,13 @@ def _create_dict_of_pages_spans_layoutparser( # Creating unique start, end of spans used as a key for merging boxes box_api = api.Box.from_mmda(box) if span_group.spans and len(span_group.spans) == 1: - start, end = ( - span_group.spans[0].start, - span_group.spans[0].end, - ) + start, end = span_group.spans[0].start, span_group.spans[0].end else: start, end = -9999, -9999 - created_span = api.Span( - start=start, end=end, box=box_api - ).to_mmda() + created_span = api.Span(start=start, + end=end, + box=box_api).to_mmda() created_span.span_id = span_group.id created_span.box_group_type = span_group.box_group.type @@ -277,19 +86,19 @@ def _create_dict_of_pages_spans_layoutparser( # Bring in the boxes from the span groups for span in span_group.spans: box_api = api.Box.from_mmda(span.box) - created_span = api.Span( - start=span.start, end=span.end, box=box_api - ).to_mmda() + created_span = api.Span(start=span.start, + end=span.end, + box=box_api).to_mmda() # Note that hash output is changing everytime it is called - created_span.span_id = f"LP_span_group_{span.box.page}_{len(span_map[span.box.page])}" + created_span.span_id = f'LP_span_group_{span.box.page}_{len(span_map[span.box.page])}' created_span.box_group_type = span_group.box_group.type span_map[span.box.page].append(created_span) return span_map @staticmethod def generate_map_of_layout_to_tokens( - vila_dict, layout_parser_overlap, dict_of_pages_layoutparser, key="caption" - ) -> Dict[int, Dict]: + vila_dict, layout_parser_overlap, dict_of_pages_layoutparser, + key='caption') -> Dict[int, Dict]: """ Generate a map of layoutparser entries to the list of vila tokens with the type = key vs type != key """ @@ -297,17 +106,14 @@ def generate_map_of_layout_to_tokens( for span in vila_dict[page]: for layout_span in dict_of_pages_layoutparser.get(page, []): if span.box.is_overlap(layout_span.box): - id_dict = layout_parser_overlap.get( - layout_span.span_id, {"caption": [], "non_caption": []} - ) + id_dict = layout_parser_overlap.get(layout_span.span_id, {'caption': [], 'non_caption': []}) id_dict[key].append(span.span_id) layout_parser_overlap[layout_span.span_id] = id_dict return layout_parser_overlap @staticmethod def generate_map_of_layout_to_tokens_for_page( - vila_list: List, layout_parser_list: List, key="caption" - ) -> Dict[int, Dict]: + vila_list: List, layout_parser_list: List, key='caption') -> Dict[int, Dict]: """ Generate a map of layoutparser tokens ids to the count of vila tokens with the type = key """ @@ -315,21 +121,16 @@ def generate_map_of_layout_to_tokens_for_page( for span in vila_list: for layout_span in layout_parser_list: if span.box.is_overlap(layout_span.box): - id_dict = layout_parser_overlap.get( - layout_span.span_id, {"caption": [], "non_caption": []} - ) + id_dict = layout_parser_overlap.get(layout_span.span_id, {'caption': [], 'non_caption': []}) if span.type.lower() == key: id_dict[key].append(span.span_id) else: - id_dict["non_caption"].append(span.span_id) + id_dict['non_caption'].append(span.span_id) layout_parser_overlap[layout_span.span_id] = id_dict return layout_parser_overlap - def update_vila_caption_dict( - self, - vila_caption_dict: Dict[int, List[Span]], - vila_non_caption_dict: Dict[int, List[Span]], - ) -> Dict[int, List[Span]]: + def update_vila_caption_dict(self, vila_caption_dict: Dict[int, List[Span]], + vila_non_caption_dict: Dict[int, List[Span]]) -> Dict[int, List[Span]]: """ Update the vila caption dict to cast tokens that are misclassified as no captions in ths same LayoutParser region @@ -337,25 +138,18 @@ def update_vila_caption_dict( layout_parser_overlap = defaultdict(dict) # Build overlap map between layoutparser and caption tokens span_map = FigureTablePredictions._create_dict_of_pages_spans_layoutparser( - self.doc.blocks - ) + self.doc.blocks) layout_parser_overlap = FigureTablePredictions.generate_map_of_layout_to_tokens( - vila_caption_dict, layout_parser_overlap, span_map - ) + vila_caption_dict, layout_parser_overlap, span_map) # Build overlap map between layoutparser and non-caption tokens layout_parser_overlap = FigureTablePredictions.generate_map_of_layout_to_tokens( - vila_non_caption_dict, layout_parser_overlap, span_map, key="non_caption" - ) + vila_non_caption_dict, layout_parser_overlap, span_map, key='non_caption') for key, value in layout_parser_overlap.items(): - caption_token_fraction = len(value["caption"]) / ( - len(value["caption"]) + len(value["non_caption"]) - ) - if (1.0 > caption_token_fraction) and ( - caption_token_fraction - > self.FRACTION_OF_MISCLASSIFIED_VILA_CAPTION_TOKENS - ): - for span_id in layout_parser_overlap[key]["non_caption"]: + caption_token_fraction = len(value['caption']) / (len(value['caption']) + len(value['non_caption'])) + if ((1.0 > caption_token_fraction) and + (caption_token_fraction > self.FRACTION_OF_MISCLASSIFIED_VILA_CAPTION_TOKENS)): + for span_id in layout_parser_overlap[key]['non_caption']: for page, vila_span in vila_non_caption_dict.items(): for entry in vila_span: if entry.span_id == span_id: @@ -364,18 +158,14 @@ def update_vila_caption_dict( return vila_caption_dict @staticmethod - def _filter_span_group( - vila_span_groups: List[api.SpanGroup], - caption_content: str, - span_group_types: List[str], - negation=False, - ) -> List[api.SpanGroup]: + def _filter_span_group(vila_span_groups: List[api.SpanGroup], caption_content: str, span_group_types: List[str], + negation=False) -> List[api.SpanGroup]: """ Helper function which filters out span groups based on the caption content and span group type """ result = [] for span_group in vila_span_groups: - if span_group.text.replace(" ", "").lower().startswith(caption_content): + if span_group.text.replace(' ', '').lower().startswith(caption_content): if span_group.type in span_group_types and not negation: result.append(span_group) elif negation and span_group.type not in span_group_types: @@ -383,8 +173,7 @@ def _filter_span_group( return result def merge_vila_token_spans( - self, caption_content: str = "fig", span_group_type: List[str] = ["Caption"] - ) -> Dict[int, List[api.Box]]: + self, caption_content: str = 'fig', span_group_type: List[str] = ['Caption']) -> Dict[int, List[api.Box]]: """ Merging spanGroups Args: @@ -394,10 +183,7 @@ def merge_vila_token_spans( Returns: Dictionary page -> List of merged boxes """ vila_span_groups_filtered = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, - caption_content=caption_content, - span_group_types=span_group_type, - ) + self.doc.vila_span_groups, caption_content=caption_content, span_group_types=span_group_type) vila_caption_dict = defaultdict(list) for entry_caption in vila_span_groups_filtered: @@ -408,12 +194,11 @@ def merge_vila_token_spans( merged_boxes_list = defaultdict(list) for page, list_of_boxes in vila_caption_dict.items(): # Merge spans if they are sufficiently close to each other - merged_boxes_list[page] = MergeSpans( - list_of_spans=list_of_boxes, w=self.w_avg * 1.5, h=self.h_avg * 1 - ).merge_neighbor_spans_by_box_coordinate() + merged_boxes_list[page] = MergeSpans(list_of_spans=list_of_boxes, w=self.w_avg * 1.5, + h=self.h_avg * 1).merge_neighbor_spans_by_box_coordinate() return merged_boxes_list - def _cast_to_caption_vila_tokens(self, caption_content="fig"): + def _cast_to_caption_vila_tokens(self, caption_content='fig'): """ Heuristic logic for fixing miss classified tokens as non caption. By checking layoutparser box predictions and tokens which belong to them, I cast the rest of the tokens to caption category. @@ -425,35 +210,19 @@ def _cast_to_caption_vila_tokens(self, caption_content="fig"): # First let's go over all the tokens which are labeled as caption and find the LayoutParser SpanGroups which # they overlap with vila_caption = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, - caption_content=caption_content, - span_group_types=["Caption"], - ) + self.doc.vila_span_groups, caption_content=caption_content, span_group_types=['Caption']) - self.vila_caption_dict = ( - FigureTablePredictions._create_dict_of_pages_spans_vila(vila_caption) - ) + self.vila_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila(vila_caption) vila_non_caption = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, - caption_content="", - span_group_types=["Caption"], - negation=True, - ) - - vila_non_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila( - vila_non_caption - ) - return self.update_vila_caption_dict( - self.vila_caption_dict, vila_non_caption_dict - ) - - def merge_boxes( - self, - layoutparser_span_groups: List[api.SpanGroup], - merged_boxes_vila_dict: Dict[int, List[api.Box]] = None, - types: List[str] = ["Figure"], - ) -> Dict[int, List[api.Box]]: + self.doc.vila_span_groups, caption_content='', span_group_types=['Caption'], negation=True) + + vila_non_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila(vila_non_caption) + return self.update_vila_caption_dict(self.vila_caption_dict, vila_non_caption_dict) + + def merge_boxes(self, layoutparser_span_groups: List[api.SpanGroup], + merged_boxes_vila_dict: Dict[int, List[api.Box]] = None, + types: List[str] = ['Figure']) -> Dict[int, List[api.Box]]: """ Merges overlapping boxes. Vila caption predictions is more consistent than layout parser prediction, thus we check the number of items after the merge with the number of caption boxes. @@ -472,40 +241,22 @@ def merge_boxes( merged_boxes_vila_dict_left = defaultdict(list) merged_boxes_map = defaultdict(list) span_map = FigureTablePredictions._create_dict_of_pages_spans_layoutparser( - layoutparser_span_groups, types=types - ) + layoutparser_span_groups, types=types) for page, span_list in span_map.items(): # Adding vila spans to the layout parser list of the spans if merged_boxes_vila_dict[page]: span_list.extend(merged_boxes_vila_dict_left[page]) - merged_spans = MergeSpans( - span_list, w=self.w_avg * 0.5, h=self.h_avg * 1.0 - ).merge_neighbor_spans_by_box_coordinate() + merged_spans = (MergeSpans(span_list, w=self.w_avg * 0.5, h=self.h_avg * 1.0) + .merge_neighbor_spans_by_box_coordinate()) # Filtering out vila spans (not merged) - if ( - len(span_list) != len(merged_spans) - and merged_boxes_vila_dict - and merged_boxes_vila_dict[page] - ): - merged_spans = [ - merged_span - for merged_span in merged_spans - if not any( - vila_span.box.to_json() == merged_span.box.to_json() - for vila_span in merged_boxes_vila_dict[page] - ) - ] - - merged_boxes_vila_dict_left[page] = [ - vila_span - for vila_span in merged_boxes_vila_dict[page] - if any( - vila_span.box.to_json() == merged_span.box.to_json() - for merged_span in merged_spans - ) - ] + if len(span_list) != len(merged_spans) and merged_boxes_vila_dict and merged_boxes_vila_dict[page]: + merged_spans = [merged_span for merged_span in merged_spans if not any( + vila_span.box.to_json() == merged_span.box.to_json() for vila_span in merged_boxes_vila_dict[page])] + + merged_boxes_vila_dict_left[page] = [vila_span for vila_span in merged_boxes_vila_dict[page] if any( + vila_span.box.to_json() == merged_span.box.to_json() for merged_span in merged_spans)] if merged_boxes_vila_dict_left[page]: merged_boxes_vila_dict[page] = merged_boxes_vila_dict_left[page] @@ -514,9 +265,7 @@ def merge_boxes( return merged_boxes_map, merged_boxes_vila_dict @staticmethod - def _get_object_caption_distance( - figure_box: api.Box, caption_box: api.Box - ) -> float: + def _get_object_caption_distance(figure_box: api.Box, caption_box: api.Box) -> float: """ Return 900.0 if left point of figure, caption is offset more than 10% Otherwise returns distance middle of the figure box and caption box @@ -533,24 +282,24 @@ def _get_object_caption_distance( return t_cap - t_fig - def get_layout_span_groups_starts_with( - self, caption_content: str = "fig", vila_spans: dict = None - ): - """ """ + def get_layout_span_groups_starts_with(self, caption_content: str = 'fig', vila_spans: dict = None): + """ + + """ spans_to_merge_dict = defaultdict(list) - self.vila_caption_dict = self._cast_to_caption_vila_tokens( - caption_content=caption_content - ) + self.vila_caption_dict = self._cast_to_caption_vila_tokens(caption_content=caption_content) if vila_spans: for page_idx, vila_span in vila_spans.items(): - spans_to_merge_dict[page_idx].extend(vila_span) + spans_to_merge_dict[page_idx].extend( + vila_span) layout_parser_span_groups_dict = defaultdict(list) if vila_spans: for page_idx, vila_span in vila_spans.items(): - layout_parser_span_groups_dict[page_idx].extend(vila_span) + layout_parser_span_groups_dict[page_idx].extend( + vila_span) return layout_parser_span_groups_dict def generate_candidates(self) -> Tuple[Union[SpanGroup, BoxGroup]]: @@ -560,66 +309,41 @@ def generate_candidates(self) -> Tuple[Union[SpanGroup, BoxGroup]]: assert self.doc.vila_span_groups merged_boxes_caption_fig_tab_dict = {} - for caption_content in ["fig", "tab"]: + for caption_content in ['fig', 'tab']: # Merge vila tokens which start with caption_content - merged_boxes_caption_fig_tab_dict[ - caption_content - ] = self.merge_vila_token_spans(caption_content=caption_content) + merged_boxes_caption_fig_tab_dict[caption_content] = self.merge_vila_token_spans( + caption_content=caption_content) - merged_boxes_caption_fig_tab_dict[ - caption_content - ] = self.get_layout_span_groups_starts_with( - caption_content=caption_content, - vila_spans=merged_boxes_caption_fig_tab_dict[caption_content], - ) + merged_boxes_caption_fig_tab_dict[caption_content] = self.get_layout_span_groups_starts_with( + caption_content=caption_content, vila_spans=merged_boxes_caption_fig_tab_dict[caption_content]) # Final check that the defined captions are starting with tab and fig - for page_idx, list_of_spans in merged_boxes_caption_fig_tab_dict[ - caption_content - ].items(): + for page_idx, list_of_spans in merged_boxes_caption_fig_tab_dict[caption_content].items(): for span in list_of_spans: - if ( - not self.doc.symbols[span.start : span.end] - .lower() - .startswith(caption_content) - ): + if not self.doc.symbols[span.start:span.end].lower().startswith(caption_content): list_of_spans.remove(span) - merged_boxes_caption_fig_tab_dict[caption_content][ - page_idx - ] = list_of_spans + merged_boxes_caption_fig_tab_dict[caption_content][page_idx] = list_of_spans # merged_boxes_vila_dict is used in figure, table boxes derivation merged_boxes_vila_dict = self.merge_vila_token_spans( - caption_content="", span_group_type=["Text", "Paragraph", "Table", "Figure"] - ) + caption_content='', span_group_type=['Text', 'Paragraph', 'Table', 'Figure']) # Create dictionary of layoutparser span groups merging boxgroups and boxes merged_boxes_vila_dict_left = None merged_boxes_fig_tab_dict = {} # List of types to be merged from layoutparser, note that sometimes figures are marked as Equations - for layout_parser_box_type in [["Figure"], ["Table"]]: - merged_boxes_vila_dict = ( - merged_boxes_vila_dict_left - if merged_boxes_vila_dict_left is not None - else merged_boxes_vila_dict - ) - ( - merged_boxes_fig_tab_dict[layout_parser_box_type[0]], - merged_boxes_vila_dict_left, - ) = self.merge_boxes( + for layout_parser_box_type in ([['Figure'], ['Table']]): + merged_boxes_vila_dict = (merged_boxes_vila_dict_left + if merged_boxes_vila_dict_left is not None else merged_boxes_vila_dict) + merged_boxes_fig_tab_dict[layout_parser_box_type[0]], merged_boxes_vila_dict_left = self.merge_boxes( layoutparser_span_groups=self.doc.blocks, types=layout_parser_box_type, - merged_boxes_vila_dict=merged_boxes_vila_dict, - ) + merged_boxes_vila_dict=merged_boxes_vila_dict) - return ( - merged_boxes_caption_fig_tab_dict["fig"], - merged_boxes_fig_tab_dict["Figure"], - merged_boxes_caption_fig_tab_dict["tab"], - merged_boxes_fig_tab_dict["Table"], - ) + return (merged_boxes_caption_fig_tab_dict['fig'], merged_boxes_fig_tab_dict['Figure'], + merged_boxes_caption_fig_tab_dict['tab'], merged_boxes_fig_tab_dict['Table']) def _predict( - self, merged_boxes_caption_dict, merged_boxes_fig_tab_dict, caption_type + self, merged_boxes_caption_dict, merged_boxes_fig_tab_dict, caption_type ) -> Dict[str, Union[SpanGroup, BoxGroup, Relation]]: """ Merges boxes corresponding to tokens of table, figure captions. For each page each caption/object create cost @@ -636,62 +360,38 @@ def _predict( predictions_captions = [] predictions_relations = [] for page in range(len(tqdm(self.doc.pages))): - if merged_boxes_caption_dict.get(page) and merged_boxes_fig_tab_dict.get( - page - ): + if merged_boxes_caption_dict.get(page) and merged_boxes_fig_tab_dict.get(page): cost_matrix = np.zeros( - ( - len(merged_boxes_fig_tab_dict[page]), - len(merged_boxes_caption_dict[page]), - ) - ) + (len(merged_boxes_fig_tab_dict[page]), + len(merged_boxes_caption_dict[page]))) for j, fig_box in enumerate(merged_boxes_fig_tab_dict[page]): for i, span_group in enumerate(merged_boxes_caption_dict[page]): caption_box = span_group.box - assert hasattr(fig_box, "box") + assert hasattr(fig_box, 'box') distance = FigureTablePredictions._get_object_caption_distance( - fig_box.box, caption_box - ) + fig_box.box, caption_box) cost_matrix[j][i] = distance - if caption_type == "Figure": + if caption_type == 'Figure': cost_matrix[j][i] = distance if distance > 0 else 900 row_ind, col_ind = linear_sum_assignment(cost_matrix) for row, col in zip(row_ind, col_ind): # Check that caption starts with tab or fig - if ( - self.doc.symbols[ - merged_boxes_caption_dict[page][col] - .start : merged_boxes_caption_dict[page][col] - .end - ] - .lower() - .startswith(caption_type.lower()[:3]) - ): - span_group = SpanGroup( - spans=[ - Span( - start=merged_boxes_caption_dict[page][col].start, - end=merged_boxes_caption_dict[page][col].end, - ) - ], - id=len(predictions_captions), - ) - box_group = BoxGroup( - boxes=[merged_boxes_fig_tab_dict[page][row].box], - id=len(predictions), + if self.doc.symbols[ + merged_boxes_caption_dict[page][col].start: + merged_boxes_caption_dict[page][col].end].lower().startswith(caption_type.lower()[:3]): + span_group = SpanGroup(spans=[Span( + start=merged_boxes_caption_dict[page][col].start, + end=merged_boxes_caption_dict[page][col].end)], + id=len(predictions_captions) ) + box_group = BoxGroup(boxes=[merged_boxes_fig_tab_dict[page][row].box], id=len(predictions)) predictions.append(box_group) predictions_captions.append(span_group) - predictions_relations.append( - Relation(from_id=box_group.id, to_id=span_group.id) - ) - return { - f"{caption_type.lower()}s": predictions, - f"{caption_type.lower()}_captions": predictions_captions, - f"{caption_type.lower()}_to_{caption_type.lower()}_captions": predictions_relations, - } + predictions_relations.append(Relation(from_id=box_group.id, to_id=span_group.id)) + return {f'{caption_type.lower()}s': predictions, f'{caption_type.lower()}_captions': predictions_captions, + f'{caption_type.lower()}_to_{caption_type.lower()}_captions': predictions_relations} def predict(self) -> Dict[str, Union[SpanGroup, BoxGroup, Relation]]: """ @@ -700,25 +400,9 @@ def predict(self) -> Dict[str, Union[SpanGroup, BoxGroup, Relation]]: information about the boundaries of figure or table. Relation stores information about the relation between caption and the object it corresponds to """ - ( - merged_boxes_caption_fig_dict, - merged_boxes_fig_dict, - merged_boxes_caption_tab_dict, - merged_boxes_tab_dict, - ) = self.generate_candidates() + (merged_boxes_caption_fig_dict, + merged_boxes_fig_dict, merged_boxes_caption_tab_dict, merged_boxes_tab_dict) = self.generate_candidates() result_dict = {} - result_dict.update( - self._predict( - merged_boxes_caption_fig_dict, - merged_boxes_fig_dict, - caption_type="Figure", - ) - ) - result_dict.update( - self._predict( - merged_boxes_caption_tab_dict, - merged_boxes_tab_dict, - caption_type="Table", - ) - ) + result_dict.update(self._predict(merged_boxes_caption_fig_dict, merged_boxes_fig_dict, caption_type='Figure')) + result_dict.update(self._predict(merged_boxes_caption_tab_dict, merged_boxes_tab_dict, caption_type='Table')) return result_dict diff --git a/src/mmda/types/document.py b/src/mmda/types/document.py index ab5760fd..4f2457d3 100644 --- a/src/mmda/types/document.py +++ b/src/mmda/types/document.py @@ -14,7 +14,7 @@ from mmda.types.indexers import Indexer, SpanGroupIndexer from mmda.types.metadata import Metadata from mmda.types.names import ImagesField, MetadataField, SymbolsField -from mmda.utils.tools import box_groups_to_span_groups +from mmda.utils.tools import MergeSpans, allocate_overlapping_tokens_for_box, box_groups_to_span_groups class Document: @@ -46,7 +46,7 @@ def add_metadata(self, **kwargs): self.metadata.set(k, value) def annotate( - self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] + self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] ) -> None: """Annotate the fields for document symbols (correlating the annotations with the symbols) and store them into the papers. @@ -54,7 +54,7 @@ def annotate( # 1) check validity of field names for field_name in kwargs.keys(): assert ( - field_name not in self.SPECIAL_FIELDS + field_name not in self.SPECIAL_FIELDS ), f"The field_name {field_name} should not be in {self.SPECIAL_FIELDS}." if field_name in self.fields: @@ -83,7 +83,7 @@ def annotate( annotation_types = {type(a) for a in annotations} assert ( - len(annotation_types) == 1 + len(annotation_types) == 1 ), f"Annotations in field_name {field_name} more than 1 type: {annotation_types}" annotation_type = annotation_types.pop() @@ -94,8 +94,7 @@ def annotate( elif annotation_type == BoxGroup: # TODO: not good. BoxGroups should be stored on their own, not auto-generating SpanGroups. span_groups = self._annotate_span_group( - span_groups=box_groups_to_span_groups(annotations, self), - field_name=field_name, + span_groups=box_groups_to_span_groups(annotations, self), field_name=field_name ) else: raise NotImplementedError( @@ -112,7 +111,7 @@ def remove(self, field_name: str): del self.__indexers[field_name] def annotate_images( - self, images: Iterable[PILImage], is_overwrite: bool = False + self, images: Iterable[PILImage], is_overwrite: bool = False ) -> None: if not is_overwrite and len(self.images) > 0: raise AssertionError( @@ -134,7 +133,7 @@ def annotate_images( self.images = images def _annotate_span_group( - self, span_groups: List[SpanGroup], field_name: str + self, span_groups: List[SpanGroup], field_name: str ) -> List[SpanGroup]: """Annotate the Document using a bunch of span groups. It will associate the annotations with the document symbols. diff --git a/src/mmda/utils/tools.py b/src/mmda/utils/tools.py index cbda562c..9effac49 100644 --- a/src/mmda/utils/tools.py +++ b/src/mmda/utils/tools.py @@ -1,21 +1,20 @@ from __future__ import annotations -import itertools import logging -from typing import Dict, List, Tuple +from collections import defaultdict +from itertools import groupby +import itertools +from typing import List, Dict, Tuple import numpy as np -from mmda.types import BoxGroup, Document, Span, SpanGroup +from mmda.types.annotation import BoxGroup, SpanGroup +from mmda.types.box import Box +from mmda.types.span import Span def allocate_overlapping_tokens_for_box( - tokens: List[SpanGroup], - box, - token_box_in_box_group: bool = False, - x: float = 0.0, - y: float = 0.0, - center: bool = False, + tokens: List[SpanGroup], box, token_box_in_box_group: bool = False, x: float = 0.0, y: float = 0.0, center: bool = False ) -> Tuple[List[Span], List[Span]]: """Finds overlap of tokens for given box Args @@ -30,14 +29,10 @@ def allocate_overlapping_tokens_for_box( """ allocated_tokens, remaining_tokens = [], [] for token in tokens: - if token_box_in_box_group and token.box_group.boxes[0].is_overlap( - other=box, x=x, y=y, center=center - ): + if token_box_in_box_group and token.box_group.boxes[0].is_overlap(other=box, x=x, y=y, center=center): # The token "box" is stored within the SpanGroup's .box_group allocated_tokens.append(token) - elif token.spans[0].box is not None and token.spans[0].box.is_overlap( - other=box, x=x, y=y, center=center - ): + elif token.spans[0].box is not None and token.spans[0].box.is_overlap(other=box, x=x, y=y, center=center): # default to assuming the token "box" is stored in the SpanGroup .box allocated_tokens.append(token) else: @@ -46,7 +41,7 @@ def allocate_overlapping_tokens_for_box( def box_groups_to_span_groups( - box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False + box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False ) -> List[SpanGroup]: """Generate SpanGroups from BoxGroups. Args @@ -65,22 +60,23 @@ def box_groups_to_span_groups( token_box_in_box_group = None for box_id, box_group in enumerate(box_groups): + all_tokens_overlapping_box_group = [] for box in box_group.boxes: + # Caching the page tokens to avoid duplicated search if box.page not in all_page_tokens: - cur_page_tokens = all_page_tokens[box.page] = doc.pages[box.page].tokens + cur_page_tokens = all_page_tokens[box.page] = doc.pages[ + box.page + ].tokens if token_box_in_box_group is None: # Determine whether box is stored on token SpanGroup span.box or in the box_group token_box_in_box_group = all( [ ( - ( - hasattr(token.box_group, "boxes") - and len(token.box_group.boxes) == 1 - ) - and token.spans[0].box is None + (hasattr(token.box_group, "boxes") and len(token.box_group.boxes) == 1) + and token.spans[0].box is None ) for token in cur_page_tokens ] @@ -88,16 +84,9 @@ def box_groups_to_span_groups( # Determine average width of tokens on this page if we are going to pad x if pad_x: if token_box_in_box_group and box.page not in avg_token_widths: - avg_token_widths[box.page] = np.average( - [t.box_group.boxes[0].w for t in cur_page_tokens] - ) - elif ( - not token_box_in_box_group - and box.page not in avg_token_widths - ): - avg_token_widths[box.page] = np.average( - [t.spans[0].box.w for t in cur_page_tokens] - ) + avg_token_widths[box.page] = np.average([t.box_group.boxes[0].w for t in cur_page_tokens]) + elif not token_box_in_box_group and box.page not in avg_token_widths: + avg_token_widths[box.page] = np.average([t.spans[0].box.w for t in cur_page_tokens]) else: cur_page_tokens = all_page_tokens[box.page] @@ -110,7 +99,7 @@ def box_groups_to_span_groups( # optionally pad x a small amount so that extra narrow token boxes (when split at punctuation) are not missed x=avg_token_widths.get(box.page, 0.0) * 0.5 if pad_x else 0.0, y=0.0, - center=center, + center=center ) all_page_tokens[box.page] = remaining_tokens @@ -124,8 +113,7 @@ def box_groups_to_span_groups( else MergeSpans( list_of_spans=list( itertools.chain.from_iterable( - span_group.spans - for span_group in all_tokens_overlapping_box_group + span_group.spans for span_group in all_tokens_overlapping_box_group ) ), index_distance=1, @@ -143,11 +131,9 @@ def box_groups_to_span_groups( ) if not token_box_in_box_group: - logging.warning( - "tokens with box stored in SpanGroup span.box will be deprecated (that is, " - "future Spans wont contain box). Ensure Document is annotated with tokens " - "having box stored in SpanGroup box_group.boxes" - ) + logging.warning("tokens with box stored in SpanGroup span.box will be deprecated (that is, " + "future Spans wont contain box). Ensure Document is annotated with tokens " + "having box stored in SpanGroup box_group.boxes") del all_page_tokens @@ -163,3 +149,170 @@ def box_groups_to_span_groups( # span_groups=derived_span_groups, field_name=field_name # ) return derived_span_groups + +class MergeSpans: + """ + Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans + which are index distance apart + Inspired by https://leetcode.com/problems/merge-intervals/ + """ + + def __init__( + self, + list_of_spans: List["Span"], + w: float = 0, + h: float = 0, + index_distance: int = 1, + ) -> None: + """ + Args + w (float): The input width between boxes to merge + h (float): The input height between the boxes to merge + index_distance (int): Distance between the spans + """ + self.list_of_spans = list_of_spans + self.w = w + self.h = h + self.graph = defaultdict(list) + self.index_distance = index_distance + + @classmethod + def from_span_groups_with_box_groups( + cls, + span_groups: List["SpanGroup"], + w: float = 0, + h: float = 0, + index_distance: int = 1, + ) -> MergeSpans: + # Convert SpanGroups with single box_group box into SpanGroups with span.box + spans_with_boxes = [] + for sg in span_groups: + assert len(sg.spans) == len( + sg.box_group.boxes + ), "Unequal number of spans and boxes for SpanGroup" + for span, box in zip(sg.spans, sg.box_group.boxes): + spans_with_boxes.append(Span(start=span.start, end=span.end, box=box)) + return cls(spans_with_boxes, w, h, index_distance) + + def build_graph_index_overlap(self): + """ + Build graph, each node is represented by (start, end) of tuple, with the list of spans. Spans are considered + overlapping if they are index_distance apart + """ + starts_matrix = np.full( + (len(self.list_of_spans), len(self.list_of_spans)), + [span.start for span in self.list_of_spans] + ) + ends_matrix = np.full( + (len(self.list_of_spans), len(self.list_of_spans)), + [span.end for span in self.list_of_spans] + ) + + starts_minus_ends = np.abs(starts_matrix - ends_matrix.T) + ends_minus_starts = np.abs(ends_matrix - starts_matrix.T) + are_neighboring_spans = np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance + neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) + + if len(neighboring_spans) > 0: + neighboring_spans_no_dupes = neighboring_spans[np.where(neighboring_spans[:,1] < neighboring_spans[:,0])] + + for j, i in neighboring_spans_no_dupes: + span_i = self.list_of_spans[i] + span_j = self.list_of_spans[j] + self.graph[span_i.start, span_i.end].append(span_j) + self.graph[span_j.start, span_j.end].append(span_i) + + def build_graph_box_overlap(self): + """ + Build graph, each node is represented by (start, end) of tuple, with the list of spans with overlapping + boxes given, w, h + """ + for i, span_i in enumerate(self.list_of_spans): + assert hasattr(span_i, "box"), "Missing attribute box in a span" + for j in range(i + 1, len(self.list_of_spans)): + assert hasattr( + self.list_of_spans[j], "box" + ), "Missing attribute box in a span" + if span_i.box.is_overlap(self.list_of_spans[j].box, self.w, self.h): + self.graph[span_i.start, span_i.end].append(self.list_of_spans[j]) + self.graph[ + self.list_of_spans[j].start, self.list_of_spans[j].end + ].append(span_i) + + # gets the connected components of the boxes overlap graph. + def get_components(self): + """ + Groups connected graph nodes into dictionary list + """ + visited = set() + comp_number = 0 + nodes_in_comp = defaultdict(list) + + def mark_component_dfs(start): + stack = [start] + while stack: + span = stack.pop() + node = span.start, span.end + if node not in visited: + visited.add(node) + nodes_in_comp[comp_number].append(span) + stack.extend(self.graph[node]) + + # mark all nodes in the same connected component with the same integer. + for span in self.list_of_spans: + center = span.start, span.end + if center not in visited: + mark_component_dfs(span) + comp_number += 1 + + return nodes_in_comp, comp_number + + def merge_neighbor_spans_by_symbol_distance(self): + """ + For each of the lists of the connected nodes determined by index distance between the spans, + merge boxes and find, min, max of the index + """ + return self.build_merged_spans_from_connected_components(index=True) + + def merge_neighbor_spans_by_box_coordinate(self): + """ + For each of the lists of the connected nodes determined by distance between the boxes, + merge boxes and find, min, max of the index + """ + return self.build_merged_spans_from_connected_components(index=False) + + def build_merged_spans_from_connected_components(self, index): + """ + For each of the lists of the connected nodes determined by symbol distance or box distance, + merge boxes and find, min, max of the index + """ + if index: + self.build_graph_index_overlap() + else: + self.build_graph_box_overlap() + + nodes_in_comp, number_of_comps = self.get_components() + + # all intervals in each connected component must be merged. + merged_spans = [] + for comp in range(number_of_comps): + if nodes_in_comp[comp]: + spans_by_page: Dict[any, List[Span]] = defaultdict(list) + for pg, page_spans in groupby( + nodes_in_comp[comp], + lambda s: s.box.page if s.box is not None else None, + ): + for span in page_spans: + spans_by_page[pg].append(span) + for page_spans in spans_by_page.values(): + merged_box = Box.small_boxes_to_big_box( + [span.box for span in page_spans] + ) + merged_spans.append( + Span( + start=min([span.start for span in page_spans]), + end=max([span.end for span in page_spans]), + box=merged_box, + ) + ) + return merged_spans diff --git a/tests/test_predictors/test_figure_table_predictors.py b/tests/test_predictors/test_figure_table_predictors.py index d0f6fa7b..1151f6b6 100644 --- a/tests/test_predictors/test_figure_table_predictors.py +++ b/tests/test_predictors/test_figure_table_predictors.py @@ -1,71 +1,54 @@ import json -import pathlib import pickle import unittest from collections import defaultdict - +import pathlib import pytest from ai2_internal.api import Relation -from mmda.predictors.heuristic_predictors.figure_table_predictors import ( - FigureTablePredictions, -) -from mmda.types import Box, BoxGroup, Document, Span +from mmda.predictors.heuristic_predictors.figure_table_predictors import FigureTablePredictions +from mmda.types import Document, BoxGroup +from mmda.types.box import Box +from mmda.types.span import Span class TestFigureCaptionPredictor(unittest.TestCase): @classmethod def setUp(cls): cls.fixture_path = pathlib.Path(__file__).parent.parent - with open( - cls.fixture_path - / "fixtures/doc_fixture_e5910c027af0ee9c1901c57f6579d903aedee7f4.pkl", - "rb", - ) as file_handle: + with open(cls.fixture_path / 'fixtures/doc_fixture_e5910c027af0ee9c1901c57f6579d903aedee7f4.pkl', + 'rb') as file_handle: doc_json = pickle.load(file_handle) cls.doc = Document.from_json(doc_json) assert cls.doc.pages assert cls.doc.tokens assert cls.doc.blocks assert cls.doc.vila_span_groups - with open( - cls.fixture_path - / "fixtures/doc_fixture_2149e0c1106e6dfa36ea787167d6611cf88b69cb.json", - "rb", - ) as file_handle: + with open(cls.fixture_path / 'fixtures/doc_fixture_2149e0c1106e6dfa36ea787167d6611cf88b69cb.json', + 'rb') as file_handle: dic_json = json.load(file_handle) - cls.doc_2 = Document.from_json(dic_json["doc"]) - layout_equations = [ - BoxGroup.from_json(entry) for entry in dic_json["layout_equations"] - ] + cls.doc_2 = Document.from_json(dic_json['doc']) + layout_equations = [BoxGroup.from_json(entry) for entry in dic_json['layout_equations']] cls.doc_2.annotate(blocks=layout_equations) - with open( - cls.fixture_path / "fixtures/figure_table_predictions.json", "r" - ) as file: + with open(cls.fixture_path / 'fixtures/figure_table_predictions.json', 'r') as file: cls.figure_predictions = json.load(file) cls.figure_table_predictor = FigureTablePredictions(cls.doc) def test_merge_boxes(self): - result = self.figure_table_predictor.merge_boxes( - self.doc.blocks, defaultdict(list) - ) + result = self.figure_table_predictor.merge_boxes(self.doc.blocks, defaultdict(list)) assert list(result[0].keys()) == [0, 2, 3, 7] assert isinstance(result[0][0][0], Span) def test_get_figure_caption_distance(self): distance = FigureTablePredictions._get_object_caption_distance( - Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), - Box(l=0.3, t=0.3, w=0.1, h=0.1, page=0), - ) + Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), Box(l=0.3, t=0.3, w=0.1, h=0.1, page=0)) assert distance == 900 distance = FigureTablePredictions._get_object_caption_distance( - Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), - Box(l=0.2, t=0.3, w=0.1, h=0.1, page=0), - ) + Box(l=0.2, t=0.2, w=0.1, h=0.1, page=0), Box(l=0.2, t=0.3, w=0.1, h=0.1, page=0)) assert distance == pytest.approx(0.15) @@ -74,17 +57,12 @@ def test_generate_map_of_layout_to_tokens(self): Test that the function generates a map of layout to tokens using """ vila_caption = FigureTablePredictions._filter_span_group( - self.doc.vila_span_groups, - caption_content="fig", - span_group_types=["Caption"], - ) + self.doc.vila_span_groups, caption_content='fig', span_group_types=['Caption']) - vila_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila( - vila_caption - ) + vila_caption_dict = FigureTablePredictions._create_dict_of_pages_spans_vila(vila_caption) result = self.figure_table_predictor.generate_map_of_layout_to_tokens( - vila_caption_dict, defaultdict(list), defaultdict(list) - ) + vila_caption_dict, + defaultdict(list), defaultdict(list)) assert list(result.keys()) == [] def test_predict_e5910c027af0ee9c1901c57f6579d903aedee7f4(self): @@ -94,134 +72,71 @@ def test_predict_e5910c027af0ee9c1901c57f6579d903aedee7f4(self): """ result = self.figure_table_predictor.predict() assert isinstance(result, dict) - assert list(result.keys()) == [ - "figures", - "figure_captions", - "figure_to_figure_captions", - "tables", - "table_captions", - "table_to_table_captions", - ] - assert len(result["figures"]) == 4 - assert len(result["tables"]) == 4 - assert isinstance(result["figure_to_figure_captions"][0], Relation) - assert isinstance(result["table_to_table_captions"][0], Relation) - assert [figure.to_json() for figure in result["figures"]] == [ - { - "boxes": [ - { - "height": 0.130624674787425, - "left": 0.5021962683185254, - "page": 0, - "top": 0.3574526237718987, - "width": 0.3930938321780535, - } - ] - }, - { - "boxes": [ - { - "height": 0.21034525861643782, - "left": 0.08724006952023973, - "page": 2, - "top": 0.09557842485832446, - "width": 0.3754700804068372, - } - ], - "id": 1, - }, - { - "boxes": [ - { - "height": 0.31222110318652835, - "left": 0.08188235294117646, - "page": 3, - "top": 0.08723311954074436, - "width": 0.37919526861851516, - } - ], - "id": 2, - }, - { - "boxes": [ - { - "height": 0.3527590433756511, - "left": 0.09958468543158637, - "page": 7, - "top": 0.08601251274648339, - "width": 0.8034834020278033, - } - ], - "id": 3, - }, - ] - assert [ - figure_caption.to_json() for figure_caption in result["figure_captions"] - ] == [ - {"id": 0, "metadata": {}, "spans": [{"end": 2057, "start": 2034}]}, - {"id": 1, "metadata": {}, "spans": [{"end": 9679, "start": 9175}]}, - {"id": 2, "metadata": {}, "spans": [{"end": 13875, "start": 13822}]}, - {"id": 3, "metadata": {}, "spans": [{"end": 31364, "start": 31224}]}, - ] - - assert [table.to_json() for table in result["tables"]] == [ - { - "boxes": [ - { - "height": 0.2796805025351168, - "left": 0.16789371515411178, - "page": 4, - "top": 0.1370883614125878, - "width": 0.6443845462175756, - } - ] - }, - { - "boxes": [ - { - "height": 0.20913203075678666, - "left": 0.1747694701151131, - "page": 5, - "top": 0.13721680882001164, - "width": 0.622537251391442, - } - ], - "id": 1, - }, - { - "boxes": [ - { - "height": 0.06003320096719145, - "left": 0.15402431114047183, - "page": 5, - "top": 0.5840287642045454, - "width": 0.2569979998021344, - } - ], - "id": 2, - }, - { - "boxes": [ - { - "height": 0.23519277090978136, - "left": 0.5027104296715431, - "page": 6, - "top": 0.27805763784081045, - "width": 0.3950077131682751, - } - ], - "id": 3, - }, - ] - - assert [ - table_caption.to_json() for table_caption in result["table_captions"] - ] == [ - {"id": 0, "metadata": {}, "spans": [{"end": 18359, "start": 18198}]}, - {"id": 1, "metadata": {}, "spans": [{"end": 22214, "start": 22042}]}, - {"id": 2, "metadata": {}, "spans": [{"end": 23502, "start": 23400}]}, - {"id": 3, "metadata": {}, "spans": [{"end": 29584, "start": 29369}]}, - ] + assert list(result.keys()) == ['figures', 'figure_captions', 'figure_to_figure_captions', 'tables', + 'table_captions', + 'table_to_table_captions', ] + assert len(result['figures']) == 4 + assert len(result['tables']) == 4 + assert isinstance(result['figure_to_figure_captions'][0], Relation) + assert isinstance(result['table_to_table_captions'][0], Relation) + assert [figure.to_json() for figure in result['figures']] == [{'boxes': [{'height': 0.130624674787425, + 'left': 0.5021962683185254, + 'page': 0, + 'top': 0.3574526237718987, + 'width': 0.3930938321780535}]}, + {'boxes': [{'height': 0.21034525861643782, + 'left': 0.08724006952023973, + 'page': 2, + 'top': 0.09557842485832446, + 'width': 0.3754700804068372}], + 'id': 1}, + {'boxes': [{'height': 0.31222110318652835, + 'left': 0.08188235294117646, + 'page': 3, + 'top': 0.08723311954074436, + 'width': 0.37919526861851516}], + 'id': 2}, + {'boxes': [{'height': 0.3527590433756511, + 'left': 0.09958468543158637, + 'page': 7, + 'top': 0.08601251274648339, + 'width': 0.8034834020278033}], + 'id': 3}] + assert [figure_caption.to_json() for figure_caption in result['figure_captions']] == [ + {'id': 0, 'metadata': {}, 'spans': [{'end': 2057, 'start': 2034}]}, + {'id': 1, 'metadata': {}, 'spans': [{'end': 9679, 'start': 9175}]}, + {'id': 2, 'metadata': {}, 'spans': [{'end': 13875, 'start': 13822}]}, + {'id': 3, 'metadata': {}, 'spans': [{'end': 31364, 'start': 31224}]}] + + assert [table.to_json() for table in result['tables']] == [{'boxes': [{'height': 0.2796805025351168, + 'left': 0.16789371515411178, + 'page': 4, + 'top': 0.1370883614125878, + 'width': 0.6443845462175756}]}, + {'boxes': [{'height': 0.20913203075678666, + 'left': 0.1747694701151131, + 'page': 5, + 'top': 0.13721680882001164, + 'width': 0.622537251391442}], + 'id': 1}, + {'boxes': [{'height': 0.06003320096719145, + 'left': 0.15402431114047183, + 'page': 5, + 'top': 0.5840287642045454, + 'width': 0.2569979998021344}], + 'id': 2}, + {'boxes': [{'height': 0.23519277090978136, + 'left': 0.5027104296715431, + 'page': 6, + 'top': 0.27805763784081045, + 'width': 0.3950077131682751}], + 'id': 3}] + + assert [table_caption.to_json() for table_caption in result['table_captions']] == [ + {'id': 0, 'metadata': {}, 'spans': [{'end': 18359, 'start': 18198}]}, + {'id': 1, 'metadata': {}, 'spans': [{'end': 22214, 'start': 22042}]}, + {'id': 2, 'metadata': {}, 'spans': [{'end': 23502, 'start': 23400}]}, + {'id': 3, 'metadata': {}, 'spans': [{'end': 29584, 'start': 29369}]}] def test_predict_2149e0c1106e6dfa36ea787167d6611cf88b69cb(self): """ @@ -231,40 +146,30 @@ def test_predict_2149e0c1106e6dfa36ea787167d6611cf88b69cb(self): self.figure_table_predictor.doc = self.doc_2 result = self.figure_table_predictor.predict() assert isinstance(result, dict) - assert list(result.keys()) == [ - "figures", - "figure_captions", - "figure_to_figure_captions", - "tables", - "table_captions", - "table_to_table_captions", - ] - assert len(result["figures"]) == 19 - assert len(result["tables"]) == 0 - assert isinstance(result["figure_to_figure_captions"][0], Relation) - assert [ - figure.to_json() for figure in result["figures"] - ] == self.figure_predictions - assert [ - figure_caption.to_json() for figure_caption in result["figure_captions"] - ] == [ - {"id": 0, "metadata": {}, "spans": [{"end": 5253, "start": 5019}]}, - {"id": 1, "metadata": {}, "spans": [{"end": 9230, "start": 8976}]}, - {"id": 2, "metadata": {}, "spans": [{"end": 13164, "start": 12935}]}, - {"id": 3, "metadata": {}, "spans": [{"end": 17600, "start": 17373}]}, - {"id": 4, "metadata": {}, "spans": [{"end": 23624, "start": 23205}]}, - {"id": 5, "metadata": {}, "spans": [{"end": 21009, "start": 20070}]}, - {"id": 6, "metadata": {}, "spans": [{"end": 28975, "start": 28838}]}, - {"id": 7, "metadata": {}, "spans": [{"end": 32839, "start": 32681}]}, - {"id": 8, "metadata": {}, "spans": [{"end": 37061, "start": 36394}]}, - {"id": 9, "metadata": {}, "spans": [{"end": 42245, "start": 42063}]}, - {"id": 10, "metadata": {}, "spans": [{"end": 43512, "start": 43418}]}, - {"id": 11, "metadata": {}, "spans": [{"end": 46726, "start": 46542}]}, - {"id": 12, "metadata": {}, "spans": [{"end": 50359, "start": 50192}]}, - {"id": 13, "metadata": {}, "spans": [{"end": 57779, "start": 57323}]}, - {"id": 14, "metadata": {}, "spans": [{"end": 60918, "start": 60838}]}, - {"id": 15, "metadata": {}, "spans": [{"end": 64943, "start": 64238}]}, - {"id": 16, "metadata": {}, "spans": [{"end": 69170, "start": 68548}]}, - {"id": 17, "metadata": {}, "spans": [{"end": 75951, "start": 75767}]}, - {"id": 18, "metadata": {}, "spans": [{"end": 80129, "start": 79561}]}, - ] + assert list(result.keys()) == ['figures', 'figure_captions', 'figure_to_figure_captions', 'tables', + 'table_captions', + 'table_to_table_captions', ] + assert len(result['figures']) == 19 + assert len(result['tables']) == 0 + assert isinstance(result['figure_to_figure_captions'][0], Relation) + assert [figure.to_json() for figure in result['figures']] == self.figure_predictions + assert [figure_caption.to_json() for figure_caption in result['figure_captions']] == [ + {'id': 0, 'metadata': {}, 'spans': [{'end': 5253, 'start': 5019}]}, + {'id': 1, 'metadata': {}, 'spans': [{'end': 9230, 'start': 8976}]}, + {'id': 2, 'metadata': {}, 'spans': [{'end': 13164, 'start': 12935}]}, + {'id': 3, 'metadata': {}, 'spans': [{'end': 17600, 'start': 17373}]}, + {'id': 4, 'metadata': {}, 'spans': [{'end': 23624, 'start': 23205}]}, + {'id': 5, 'metadata': {}, 'spans': [{'end': 21009, 'start': 20070}]}, + {'id': 6, 'metadata': {}, 'spans': [{'end': 28975, 'start': 28838}]}, + {'id': 7, 'metadata': {}, 'spans': [{'end': 32839, 'start': 32681}]}, + {'id': 8, 'metadata': {}, 'spans': [{'end': 37061, 'start': 36394}]}, + {'id': 9, 'metadata': {}, 'spans': [{'end': 42245, 'start': 42063}]}, + {'id': 10, 'metadata': {}, 'spans': [{'end': 43512, 'start': 43418}]}, + {'id': 11, 'metadata': {}, 'spans': [{'end': 46726, 'start': 46542}]}, + {'id': 12, 'metadata': {}, 'spans': [{'end': 50359, 'start': 50192}]}, + {'id': 13, 'metadata': {}, 'spans': [{'end': 57779, 'start': 57323}]}, + {'id': 14, 'metadata': {}, 'spans': [{'end': 60918, 'start': 60838}]}, + {'id': 15, 'metadata': {}, 'spans': [{'end': 64943, 'start': 64238}]}, + {'id': 16, 'metadata': {}, 'spans': [{'end': 69170, 'start': 68548}]}, + {'id': 17, 'metadata': {}, 'spans': [{'end': 75951, 'start': 75767}]}, + {'id': 18, 'metadata': {}, 'spans': [{'end': 80129, 'start': 79561}]}] diff --git a/tests/test_utils/test_tools.py b/tests/test_utils/test_tools.py index e6c1c6e5..6b39bdb1 100644 --- a/tests/test_utils/test_tools.py +++ b/tests/test_utils/test_tools.py @@ -10,15 +10,18 @@ import unittest from mmda.types.annotation import BoxGroup, SpanGroup +from mmda.types.span import Span from mmda.types.box import Box from mmda.types.document import Document -from mmda.types.span import Span -from mmda.utils.tools import MergeSpans, box_groups_to_span_groups + +from mmda.utils.tools import MergeSpans +from mmda.utils.tools import box_groups_to_span_groups fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" / "utils" class TestMergeNeighborSpans(unittest.TestCase): + def test_merge_multiple_neighbor_spans(self): spans = [Span(start=0, end=10), Span(start=11, end=20), Span(start=21, end=30)] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) @@ -51,9 +54,7 @@ def test_different_index_distances(self): def test_zero_index_distance(self): spans = [Span(start=0, end=10), Span(start=10, end=20)] - out = MergeSpans( - list_of_spans=spans, index_distance=0 - ).merge_neighbor_spans_by_symbol_distance() + out = MergeSpans(list_of_spans=spans, index_distance=0).merge_neighbor_spans_by_symbol_distance() assert len(out) == 1 assert isinstance(out[0], Span) assert out[0].start == 0 @@ -63,7 +64,7 @@ def test_handling_of_boxes(self): spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=0)), - Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), + Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -83,7 +84,7 @@ def test_handling_of_boxes(self): spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=1)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=1)), - Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), + Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) @@ -103,7 +104,7 @@ def test_handling_of_boxes(self): Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20), Span(start=21, end=150), - Span(start=155, end=200), + Span(start=155, end=200) ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -125,721 +126,223 @@ def test_handling_of_boxes(self): list_of_spans_to_merge = [ - Span( - start=3944, - end=3948, - box=Box( - l=0.19238134915568578, - t=0.22752901673615306, - w=0.06941334053447479, - h=0.029442207414270286, - page=4, - ), - ), - Span( - start=3949, - end=3951, - box=Box( - l=0.27220460878651254, - t=0.22752901673615306, - w=0.03468585042904468, - h=0.029442207414270286, - page=4, - ), - ), - Span( - start=4060, - end=4063, - box=Box( - l=0.4204075769894973, - t=0.34144142726484455, - w=0.023417310961637895, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4072, - end=4075, - box=Box( - l=0.5182742633669088, - t=0.34144142726484455, - w=0.029000512031393755, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4076, - end=4083, - box=Box( - l=0.5522956396696659, - t=0.34144142726484455, - w=0.06440764687304719, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4119, - end=4128, - box=Box( - l=0.2686971421659869, - t=0.36273518298114954, - w=0.08479235581478171, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4134, - end=4144, - box=Box( - l=0.40387889180816966, - t=0.36273518298114954, - w=0.08368776567508182, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4145, - end=4148, - box=Box( - l=0.4943548659781345, - t=0.36273518298114954, - w=0.042396177907390975, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4149, - end=4162, - box=Box( - l=0.5435392523804085, - t=0.36273518298114954, - w=0.11491754144296094, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4166, - end=4177, - box=Box( - l=0.6876581404256177, - t=0.36273518298114954, - w=0.09146006356715199, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4419, - end=4427, - box=Box( - l=0.2686971421659869, - t=0.4479113936500019, - w=0.06846450520430858, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4497, - end=4505, - box=Box( - l=0.2686971421659869, - t=0.46920514936630686, - w=0.06846450520430858, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4517, - end=4520, - box=Box( - l=0.42195400318507725, - t=0.46920514936630686, - w=0.029000512031393755, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4574, - end=4581, - box=Box( - l=0.2686971421659869, - t=0.49049890508261185, - w=0.07810456460532592, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4582, - end=4587, - box=Box( - l=0.35061756361754887, - t=0.49049890508261185, - w=0.03904224057412029, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4588, - end=4591, - box=Box( - l=0.39347566103790516, - t=0.49049890508261185, - w=0.023417310961637943, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4592, - end=4601, - box=Box( - l=0.4207088288457791, - t=0.49049890508261185, - w=0.08254300862121101, - h=0.014200429984914883, - page=4, - ), - ), - Span( - start=4602, - end=4613, - box=Box( - l=0.5070676943132262, - t=0.49049890508261185, - w=0.09481400090042272, - h=0.014200429984914883, - page=4, - ), - ), -] - -list_of_spans_to_merge_2 = [ - Span( - start=30113, - end=30119, - box=Box( - l=0.12095229775767885, - t=0.3578497466414853, - w=0.05243790645011725, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30120, - end=30124, - box=Box( - l=0.17929474059091924, - t=0.3578497466414853, - w=0.030687522426571887, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30125, - end=30129, - box=Box( - l=0.21799556239458678, - t=0.3578497466414853, - w=0.04350076804709073, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30130, - end=30135, - box=Box( - l=0.26740086682480063, - t=0.3578497466414853, - w=0.050208642713631964, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30136, - end=30141, - box=Box( - l=0.32351404592155575, - t=0.3578497466414853, - w=0.0446254416438761, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30142, - end=30151, - box=Box( - l=0.37404402394855496, - t=0.3578497466414853, - w=0.0769598075514552, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30152, - end=30155, - box=Box( - l=0.4569284513402187, - t=0.3578497466414853, - w=0.029000512031393852, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30156, - end=30165, - box=Box( - l=0.4918334997547357, - t=0.3578497466414853, - w=0.0792091547450259, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30166, - end=30175, - box=Box( - l=0.5769471908828846, - t=0.3578497466414853, - w=0.07175819216632291, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30176, - end=30179, - box=Box( - l=0.6576023545380633, - t=0.3578497466414853, - w=0.03122977576787907, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30180, - end=30184, - box=Box( - l=0.6947366666890655, - t=0.3578497466414853, - w=0.03904224057412024, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30185, - end=30190, - box=Box( - l=0.7396834436463088, - t=0.3578497466414853, - w=0.05020864271363187, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30191, - end=30193, - box=Box( - l=0.7957966227430638, - t=0.3578497466414853, - w=0.015624929612482252, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30194, - end=30197, - box=Box( - l=0.12095229775767885, - t=0.37500875791374183, - w=0.024541984558423317, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30198, - end=30207, - box=Box( - l=0.1518205712980198, - t=0.37500875791374183, - w=0.07695980755145514, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30208, - end=30210, - box=Box( - l=0.2351066678313926, - t=0.37500875791374183, - w=0.013395665875996984, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30211, - end=30214, - box=Box( - l=0.2548286226893072, - t=0.37500875791374183, - w=0.02231272082193805, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30215, - end=30217, - box=Box( - l=0.283467632493163, - t=0.37500875791374183, - w=0.015624929612482252, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30218, - end=30221, - box=Box( - l=0.3054188510875629, - t=0.37500875791374183, - w=0.024541984558423317, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30222, - end=30229, - box=Box( - l=0.33628712462790383, - t=0.37500875791374183, - w=0.055570925755447906, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30230, - end=30235, - box=Box( - l=0.3981843393652693, - t=0.37500875791374183, - w=0.04183384110899822, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30236, - end=30240, - box=Box( - l=0.44668588822663785, - t=0.37500875791374183, - w=0.03570838669793504, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30241, - end=30244, - box=Box( - l=0.4887205639064905, - t=0.37500875791374183, - w=0.020083457085452783, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30245, - end=30255, - box=Box( - l=0.5151303099738609, - t=0.37500875791374183, - w=0.08810612623388145, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30256, - end=30259, - box=Box( - l=0.6095627251896601, - t=0.37500875791374183, - w=0.022312720821938, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30260, - end=30262, - box=Box( - l=0.6382017349935157, - t=0.37500875791374183, - w=0.015624929612482252, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30263, - end=30268, - box=Box( - l=0.6601529535879158, - t=0.37500875791374183, - w=0.03958449391542752, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30269, - end=30273, - box=Box( - l=0.7098795933314969, - t=0.37500875791374183, - w=0.035708386697935225, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30274, - end=30276, - box=Box( - l=0.7519142690113497, - t=0.37500875791374183, - w=0.013395665875997033, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30277, - end=30278, - box=Box( - l=0.7716362238692644, - t=0.37500875791374183, - w=0.008917054945941066, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30279, - end=30281, - box=Box( - l=0.7868795677971232, - t=0.37500875791374183, - w=0.02454198455842322, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30282, - end=30291, - box=Box( - l=0.12095229775767885, - t=0.3921677691859983, - w=0.08031374488472577, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30292, - end=30296, - box=Box( - l=0.2062869069137678, - t=0.3921677691859983, - w=0.03904224057412024, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30297, - end=30302, - box=Box( - l=0.25035001175925126, - t=0.3921677691859983, - w=0.050208642713631964, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30303, - end=30311, - box=Box( - l=0.30557951874424644, - t=0.3921677691859983, - w=0.08143841848151108, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30312, - end=30314, - box=Box( - l=0.3920388014971207, - t=0.3921677691859983, - w=0.016729519752182193, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30315, - end=30321, - box=Box( - l=0.4137891855206661, - t=0.3921677691859983, - w=0.0535625800469026, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30322, - end=30328, - box=Box( - l=0.47237262983893197, - t=0.3921677691859983, - w=0.05354249658981717, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30329, - end=30333, - box=Box( - l=0.5309359907001122, - t=0.3921677691859983, - w=0.03681297683763493, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30334, - end=30336, - box=Box( - l=0.5727698318091105, - t=0.3921677691859983, - w=0.01672951975218224, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30337, - end=30344, - box=Box( - l=0.5945202158326559, - t=0.3921677691859983, - w=0.060230287799273016, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30345, - end=30348, - box=Box( - l=0.6597713679032922, - t=0.3921677691859983, - w=0.029000512031393946, - h=0.014200429984914883, - page=19, - ), - ), - Span( - start=30349, - end=30359, - box=Box( - l=0.6937927442060494, - t=0.3921677691859983, - w=0.07834556609035141, - h=0.014200429984914883, - page=19, - ), - ), -] + Span(start=3944, end=3948, + box=Box(l=0.19238134915568578, t=0.22752901673615306, w=0.06941334053447479, h=0.029442207414270286, + page=4)), + Span(start=3949, end=3951, + box=Box(l=0.27220460878651254, t=0.22752901673615306, w=0.03468585042904468, h=0.029442207414270286, + page=4)), + Span(start=4060, end=4063, + box=Box(l=0.4204075769894973, t=0.34144142726484455, w=0.023417310961637895, h=0.014200429984914883, + page=4)), + Span(start=4072, end=4075, + box=Box(l=0.5182742633669088, t=0.34144142726484455, w=0.029000512031393755, h=0.014200429984914883, + page=4)), + Span(start=4076, end=4083, + box=Box(l=0.5522956396696659, t=0.34144142726484455, w=0.06440764687304719, h=0.014200429984914883, + page=4)), + Span(start=4119, end=4128, + box=Box(l=0.2686971421659869, t=0.36273518298114954, w=0.08479235581478171, h=0.014200429984914883, + page=4)), + Span(start=4134, end=4144, + box=Box(l=0.40387889180816966, t=0.36273518298114954, w=0.08368776567508182, h=0.014200429984914883, + page=4)), + Span(start=4145, end=4148, + box=Box(l=0.4943548659781345, t=0.36273518298114954, w=0.042396177907390975, h=0.014200429984914883, + page=4)), + Span(start=4149, end=4162, + box=Box(l=0.5435392523804085, t=0.36273518298114954, w=0.11491754144296094, h=0.014200429984914883, + page=4)), + Span(start=4166, end=4177, + box=Box(l=0.6876581404256177, t=0.36273518298114954, w=0.09146006356715199, h=0.014200429984914883, + page=4)), + Span(start=4419, end=4427, + box=Box(l=0.2686971421659869, t=0.4479113936500019, w=0.06846450520430858, h=0.014200429984914883, + page=4)), + Span(start=4497, end=4505, + box=Box(l=0.2686971421659869, t=0.46920514936630686, w=0.06846450520430858, h=0.014200429984914883, + page=4)), + Span(start=4517, end=4520, + box=Box(l=0.42195400318507725, t=0.46920514936630686, w=0.029000512031393755, h=0.014200429984914883, + page=4)), + Span(start=4574, end=4581, + box=Box(l=0.2686971421659869, t=0.49049890508261185, w=0.07810456460532592, h=0.014200429984914883, + page=4)), + Span(start=4582, end=4587, + box=Box(l=0.35061756361754887, t=0.49049890508261185, w=0.03904224057412029, h=0.014200429984914883, + page=4)), + Span(start=4588, end=4591, + box=Box(l=0.39347566103790516, t=0.49049890508261185, w=0.023417310961637943, h=0.014200429984914883, + page=4)), + Span(start=4592, end=4601, + box=Box(l=0.4207088288457791, t=0.49049890508261185, w=0.08254300862121101, h=0.014200429984914883, + page=4)), + Span(start=4602, end=4613, + box=Box(l=0.5070676943132262, t=0.49049890508261185, w=0.09481400090042272, h=0.014200429984914883, + page=4)),] + +list_of_spans_to_merge_2 = [Span(start=30113, end=30119, + box=Box(l=0.12095229775767885, t=0.3578497466414853, w=0.05243790645011725, + h=0.014200429984914883, page=19)), + Span(start=30120, end=30124, + box=Box(l=0.17929474059091924, t=0.3578497466414853, w=0.030687522426571887, + h=0.014200429984914883, + page=19)), + Span(start=30125, end=30129, + box=Box(l=0.21799556239458678, t=0.3578497466414853, w=0.04350076804709073, + h=0.014200429984914883, page=19)), + Span(start=30130, end=30135, + box=Box(l=0.26740086682480063, t=0.3578497466414853, w=0.050208642713631964, + h=0.014200429984914883, + page=19)), + Span(start=30136, end=30141, + box=Box(l=0.32351404592155575, t=0.3578497466414853, w=0.0446254416438761, + h=0.014200429984914883, page=19)), + Span(start=30142, end=30151, + box=Box(l=0.37404402394855496, t=0.3578497466414853, w=0.0769598075514552, + h=0.014200429984914883, page=19)), + Span(start=30152, end=30155, + box=Box(l=0.4569284513402187, t=0.3578497466414853, w=0.029000512031393852, + h=0.014200429984914883, page=19)), + Span(start=30156, end=30165, + box=Box(l=0.4918334997547357, t=0.3578497466414853, w=0.0792091547450259, + h=0.014200429984914883, page=19)), + Span(start=30166, end=30175, + box=Box(l=0.5769471908828846, t=0.3578497466414853, w=0.07175819216632291, + h=0.014200429984914883, page=19)), + Span(start=30176, end=30179, + box=Box(l=0.6576023545380633, t=0.3578497466414853, w=0.03122977576787907, + h=0.014200429984914883, page=19)), + Span(start=30180, end=30184, + box=Box(l=0.6947366666890655, t=0.3578497466414853, w=0.03904224057412024, + h=0.014200429984914883, page=19)), + Span(start=30185, end=30190, + box=Box(l=0.7396834436463088, t=0.3578497466414853, w=0.05020864271363187, + h=0.014200429984914883, page=19)), + Span(start=30191, end=30193, + box=Box(l=0.7957966227430638, t=0.3578497466414853, w=0.015624929612482252, + h=0.014200429984914883, page=19)), + Span(start=30194, end=30197, + box=Box(l=0.12095229775767885, t=0.37500875791374183, w=0.024541984558423317, + h=0.014200429984914883, + page=19)), + Span(start=30198, end=30207, + box=Box(l=0.1518205712980198, t=0.37500875791374183, w=0.07695980755145514, + h=0.014200429984914883, page=19)), + Span(start=30208, end=30210, + box=Box(l=0.2351066678313926, t=0.37500875791374183, w=0.013395665875996984, + h=0.014200429984914883, + page=19)), + Span(start=30211, end=30214, + box=Box(l=0.2548286226893072, t=0.37500875791374183, w=0.02231272082193805, + h=0.014200429984914883, page=19)), + Span(start=30215, end=30217, + box=Box(l=0.283467632493163, t=0.37500875791374183, w=0.015624929612482252, + h=0.014200429984914883, page=19)), + Span(start=30218, end=30221, + box=Box(l=0.3054188510875629, t=0.37500875791374183, w=0.024541984558423317, + h=0.014200429984914883, + page=19)), + Span(start=30222, end=30229, + box=Box(l=0.33628712462790383, t=0.37500875791374183, w=0.055570925755447906, + h=0.014200429984914883, + page=19)), + Span(start=30230, end=30235, + box=Box(l=0.3981843393652693, t=0.37500875791374183, w=0.04183384110899822, + h=0.014200429984914883, page=19)), + Span(start=30236, end=30240, + box=Box(l=0.44668588822663785, t=0.37500875791374183, w=0.03570838669793504, + h=0.014200429984914883, + page=19)), + Span(start=30241, end=30244, + box=Box(l=0.4887205639064905, t=0.37500875791374183, w=0.020083457085452783, + h=0.014200429984914883, + page=19)), + Span(start=30245, end=30255, + box=Box(l=0.5151303099738609, t=0.37500875791374183, w=0.08810612623388145, + h=0.014200429984914883, page=19)), + Span(start=30256, end=30259, + box=Box(l=0.6095627251896601, t=0.37500875791374183, w=0.022312720821938, + h=0.014200429984914883, page=19)), + Span(start=30260, end=30262, + box=Box(l=0.6382017349935157, t=0.37500875791374183, w=0.015624929612482252, + h=0.014200429984914883, + page=19)), + Span(start=30263, end=30268, + box=Box(l=0.6601529535879158, t=0.37500875791374183, w=0.03958449391542752, + h=0.014200429984914883, page=19)), + Span(start=30269, end=30273, + box=Box(l=0.7098795933314969, t=0.37500875791374183, w=0.035708386697935225, + h=0.014200429984914883, + page=19)), + Span(start=30274, end=30276, + box=Box(l=0.7519142690113497, t=0.37500875791374183, w=0.013395665875997033, + h=0.014200429984914883, + page=19)), + Span(start=30277, end=30278, + box=Box(l=0.7716362238692644, t=0.37500875791374183, w=0.008917054945941066, + h=0.014200429984914883, + page=19)), + Span(start=30279, end=30281, + box=Box(l=0.7868795677971232, t=0.37500875791374183, w=0.02454198455842322, + h=0.014200429984914883, page=19)), + Span(start=30282, end=30291, + box=Box(l=0.12095229775767885, t=0.3921677691859983, w=0.08031374488472577, + h=0.014200429984914883, page=19)), + Span(start=30292, end=30296, + box=Box(l=0.2062869069137678, t=0.3921677691859983, w=0.03904224057412024, + h=0.014200429984914883, page=19)), + Span(start=30297, end=30302, + box=Box(l=0.25035001175925126, t=0.3921677691859983, w=0.050208642713631964, + h=0.014200429984914883, + page=19)), + Span(start=30303, end=30311, + box=Box(l=0.30557951874424644, t=0.3921677691859983, w=0.08143841848151108, + h=0.014200429984914883, page=19)), + Span(start=30312, end=30314, + box=Box(l=0.3920388014971207, t=0.3921677691859983, w=0.016729519752182193, + h=0.014200429984914883, page=19)), + Span(start=30315, end=30321, + box=Box(l=0.4137891855206661, t=0.3921677691859983, w=0.0535625800469026, + h=0.014200429984914883, page=19)), + Span(start=30322, end=30328, + box=Box(l=0.47237262983893197, t=0.3921677691859983, w=0.05354249658981717, + h=0.014200429984914883, page=19)), + Span(start=30329, end=30333, + box=Box(l=0.5309359907001122, t=0.3921677691859983, w=0.03681297683763493, + h=0.014200429984914883, page=19)), + Span(start=30334, end=30336, + box=Box(l=0.5727698318091105, t=0.3921677691859983, w=0.01672951975218224, + h=0.014200429984914883, page=19)), + Span(start=30337, end=30344, + box=Box(l=0.5945202158326559, t=0.3921677691859983, w=0.060230287799273016, + h=0.014200429984914883, page=19)), + Span(start=30345, end=30348, + box=Box(l=0.6597713679032922, t=0.3921677691859983, w=0.029000512031393946, + h=0.014200429984914883, page=19)), + Span(start=30349, end=30359, + box=Box(l=0.6937927442060494, t=0.3921677691859983, w=0.07834556609035141, + h=0.014200429984914883, page=19))] def test_merge_spans(): - assert len(list_of_spans_to_merge) == ( - len( - MergeSpans( - list_of_spans_to_merge, 0, 0 - ).merge_neighbor_spans_by_box_coordinate() - ) - ) + assert len(list_of_spans_to_merge) == (len(MergeSpans(list_of_spans_to_merge, 0, 0) + .merge_neighbor_spans_by_box_coordinate())) - assert 4 == len( - MergeSpans( - list_of_spans_to_merge, 0.04387334, 0.01421097 - ).merge_neighbor_spans_by_box_coordinate() - ) + assert 4 == len(MergeSpans(list_of_spans_to_merge, 0.04387334, 0.01421097).merge_neighbor_spans_by_box_coordinate()) merge_spans = MergeSpans(list_of_spans_to_merge_2, 0.04387334, 0.01421097) assert 1 == len(merge_spans.merge_neighbor_spans_by_box_coordinate()) - assert [30113, 30359] == [ - merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, - merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end, - ] + assert [30113, 30359] == [merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end] def test_merge_neighbor_spans_by_symbol_distance(): - assert 7 == ( - len( - MergeSpans( - list_of_spans_to_merge, index_distance=10 - ).merge_neighbor_spans_by_symbol_distance() - ) - ) + assert 7 == (len(MergeSpans(list_of_spans_to_merge, index_distance=10) + .merge_neighbor_spans_by_symbol_distance())) + - assert 10 == len( - MergeSpans( - list_of_spans_to_merge, index_distance=1 - ).merge_neighbor_spans_by_symbol_distance() - ) + assert 10 == len(MergeSpans(list_of_spans_to_merge, index_distance=1).merge_neighbor_spans_by_symbol_distance()) list_of_spans_to_merge_2 = [ Span(start=1, end=3, box=Box(l=0.1, t=0.2, w=0.2, h=0.2, page=11)), @@ -867,41 +370,30 @@ def test_from_span_groups_with_box_groups(): list_of_spans_to_merge_in_span_group_format.append( SpanGroup( spans=[Span(start=span.start, end=span.end)], - box_group=BoxGroup(boxes=[span.box]), + box_group=BoxGroup(boxes=[span.box]) ) ) - assert 7 == ( - len( - MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, index_distance=10 - ).merge_neighbor_spans_by_symbol_distance() - ) - ) + assert 7 == (len(MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, + index_distance=10).merge_neighbor_spans_by_symbol_distance()) + ) - assert len(list_of_spans_to_merge) == ( - len( - MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, 0, 0 - ).merge_neighbor_spans_by_box_coordinate() - ) - ) + assert len(list_of_spans_to_merge) == (len(MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, + 0, + 0).merge_neighbor_spans_by_box_coordinate())) def test_box_groups_to_span_groups(): # basic doc annotated with pages and tokens, from pdfplumber parser split at punctuation - with open( - fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", - "r", - ) as f: + with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", "r") as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) doc = Document.from_json(fixture_doc_json) # boxes drawn neatly around bib entries - with open( - fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r" - ) as f: + with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r") as f: raw_json = f.read() fixture_bib_entries_json = json.loads(raw_json)["bib_entries"] @@ -912,28 +404,15 @@ def test_box_groups_to_span_groups(): # generate span_groups with different settings overlap_span_groups = box_groups_to_span_groups(box_groups, doc, center=False) - overlap_at_token_center_span_groups = box_groups_to_span_groups( - box_groups, doc, center=True - ) - overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups( - box_groups, doc, center=True, pad_x=True - ) - - assert ( - len(box_groups) - == len(overlap_span_groups) - == len(overlap_at_token_center_span_groups) - == len(overlap_at_token_center_span_groups_x_padded) - ) + overlap_at_token_center_span_groups = box_groups_to_span_groups(box_groups, doc, center=True) + overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups(box_groups, doc, center=True, pad_x=True) + + assert (len(box_groups) == len(overlap_span_groups) == len(overlap_at_token_center_span_groups) == len(overlap_at_token_center_span_groups_x_padded)) # annotate all onto doc to extract texts: doc.annotate(overlap_span_groups=overlap_span_groups) - doc.annotate( - overlap_at_token_center_span_groups=overlap_at_token_center_span_groups - ) - doc.annotate( - overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded - ) + doc.annotate(overlap_at_token_center_span_groups=overlap_at_token_center_span_groups) + doc.annotate(overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded) # when center=False, any token overlap with BoxGroup becomes part of the SpanGroup # in this example, tokens from bib entry '29 and '31' overlap with the box drawn neatly around '30' @@ -965,6 +444,4 @@ def test_box_groups_to_span_groups(): assert doc.overlap_at_token_center_span_groups_x_padded[6].text.startswith("[6]") # original box_group boxes are saved - assert all( - [sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups] - ) + assert all([sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups]) From bb4b740cab58a7abdd4a54566a1bfd23d1bf8392 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 6 Jul 2023 10:15:04 -0700 Subject: [PATCH 20/25] bugfix; remove repeat method in Span --- src/mmda/types/span.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/mmda/types/span.py b/src/mmda/types/span.py index 98402898..d114c557 100644 --- a/src/mmda/types/span.py +++ b/src/mmda/types/span.py @@ -60,9 +60,6 @@ def small_spans_to_big_span( box=new_box, ) - def is_overlap(self, other: "Span") -> bool: - return self.start < other.end and other.start < self.end - @classmethod def cluster_spans(cls, spans: List["Span"]) -> List[List[int]]: """ From 070ce042164ef220668b067e47f2419ba9087e7d Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 6 Jul 2023 10:49:09 -0700 Subject: [PATCH 21/25] add test for unsorted span merging --- src/mmda/types/span.py | 5 ++--- tests/test_types/test_span.py | 38 ++++++++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/mmda/types/span.py b/src/mmda/types/span.py index d114c557..53cd931a 100644 --- a/src/mmda/types/span.py +++ b/src/mmda/types/span.py @@ -41,8 +41,7 @@ def __lt__(self, other: "Span"): def small_spans_to_big_span( cls, spans: List["Span"], merge_boxes: bool = True ) -> "Span": - # TODO: add warning for unsorted spans or not-contiguous spans - # TODO: what happens when Boxes cant be merged? + # TODO: add warning for non-contiguous spans? start = spans[0].start end = spans[0].end for span in spans[1:]: @@ -50,7 +49,7 @@ def small_spans_to_big_span( start = span.start if span.end > end: end = span.end - if merge_boxes: + if merge_boxes and all(span.box for span in spans): new_box = Box.small_boxes_to_big_box(boxes=[span.box for span in spans]) else: new_box = None diff --git a/tests/test_types/test_span.py b/tests/test_types/test_span.py index 844a5b84..1518a114 100644 --- a/tests/test_types/test_span.py +++ b/tests/test_types/test_span.py @@ -5,9 +5,9 @@ class TestSpan(unittest.TestCase): - def setUp(cls): - cls.span = mmda_span.Span(start=0, end=0) - cls.span_dict = { + def test_to_from_json(self): + span = mmda_span.Span(start=0, end=0) + span_dict = { "start": 0, "end": 8, "box": { @@ -18,10 +18,8 @@ def setUp(cls): "page": 0, }, } - - def test_to_from_json(self): self.assertEqual( - self.span.from_json(self.span_dict).to_json(), + span.from_json(span_dict).to_json(), mmda_span.Span( start=0, end=8, @@ -50,7 +48,33 @@ def test_small_spans_to_big_span(self): mmda_span.Span(start=16, end=24), ] self.assertEqual( - self.span.small_spans_to_big_span(spans=spans, merge_boxes=False), + mmda_span.Span.small_spans_to_big_span(spans=spans, merge_boxes=False), + mmda_span.Span(start=0, end=24), + ) + # if no boxes, should still work + self.assertEqual( + mmda_span.Span.small_spans_to_big_span(spans=spans, merge_boxes=True), + mmda_span.Span(start=0, end=24), + ) + + def test_small_spans_to_big_span_unsorted(self): + spans = [ + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=16, end=24), + ] + self.assertEqual( + mmda_span.Span.small_spans_to_big_span(spans=spans), + mmda_span.Span(start=0, end=24), + ) + + spans = [ + mmda_span.Span(start=16, end=24), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=0, end=8), + ] + self.assertEqual( + mmda_span.Span.small_spans_to_big_span(spans=spans), mmda_span.Span(start=0, end=24), ) From e8953fce1fe415760fabaf16685a581c1e6ecfad Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 6 Jul 2023 11:07:15 -0700 Subject: [PATCH 22/25] update tools tests so Spans have Boxes --- tests/test_utils/test_tools.py | 1033 ++++++++++++++++++++++++-------- 1 file changed, 790 insertions(+), 243 deletions(-) diff --git a/tests/test_utils/test_tools.py b/tests/test_utils/test_tools.py index 6b39bdb1..59e8f2a9 100644 --- a/tests/test_utils/test_tools.py +++ b/tests/test_utils/test_tools.py @@ -10,20 +10,21 @@ import unittest from mmda.types.annotation import BoxGroup, SpanGroup -from mmda.types.span import Span from mmda.types.box import Box from mmda.types.document import Document - -from mmda.utils.tools import MergeSpans -from mmda.utils.tools import box_groups_to_span_groups +from mmda.types.span import Span +from mmda.utils.tools import MergeSpans, box_groups_to_span_groups fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" / "utils" class TestMergeNeighborSpans(unittest.TestCase): - def test_merge_multiple_neighbor_spans(self): - spans = [Span(start=0, end=10), Span(start=11, end=20), Span(start=21, end=30)] + spans = [ + Span(start=0, end=10, box=Box(l=1.0, t=2.0, w=3.0, h=4.0, page=0)), + Span(start=11, end=20, box=Box(l=5.0, t=6.0, w=7.0, h=8.0, page=0)), + Span(start=21, end=30, box=Box(l=9.0, t=10.0, w=11.0, h=12.0, page=0)), + ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) out = merge_spans.merge_neighbor_spans_by_symbol_distance() assert len(out) == 1 @@ -32,18 +33,33 @@ def test_merge_multiple_neighbor_spans(self): assert out[0].end == 30 def test_different_index_distances(self): - spans = [Span(start=0, end=10), Span(start=15, end=20)] + spans = [ + Span(start=0, end=10, box=Box(l=1.0, t=2.0, w=3.0, h=4.0, page=0)), + Span(start=15, end=20, box=Box(l=5.0, t=6.0, w=7.0, h=8.0, page=0)), + ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) out = merge_spans.merge_neighbor_spans_by_symbol_distance() - assert out == spans # no merge happened + # no merge happened + for s1, s2 in zip(spans, out): + assert s1.start == s2.start + assert s1.end == s2.end + assert s1.box.coordinates == s2.box.coordinates merge_spans = MergeSpans(list_of_spans=spans, index_distance=2) out = merge_spans.merge_neighbor_spans_by_symbol_distance() - assert out == spans # no merge happened + # no merge happened + for s1, s2 in zip(spans, out): + assert s1.start == s2.start + assert s1.end == s2.end + assert s1.box.coordinates == s2.box.coordinates merge_spans = MergeSpans(list_of_spans=spans, index_distance=4) out = merge_spans.merge_neighbor_spans_by_symbol_distance() - assert out == spans # no merge happened + # no merge happened + for s1, s2 in zip(spans, out): + assert s1.start == s2.start + assert s1.end == s2.end + assert s1.box.coordinates == s2.box.coordinates merge_spans = MergeSpans(list_of_spans=spans, index_distance=5) out = merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -53,8 +69,13 @@ def test_different_index_distances(self): assert out[0].end == 20 def test_zero_index_distance(self): - spans = [Span(start=0, end=10), Span(start=10, end=20)] - out = MergeSpans(list_of_spans=spans, index_distance=0).merge_neighbor_spans_by_symbol_distance() + spans = [ + Span(start=0, end=10, box=Box(l=1.0, t=2.0, w=3.0, h=4.0, page=0)), + Span(start=10, end=20, box=Box(l=5.0, t=6.0, w=7.0, h=8.0, page=0)), + ] + out = MergeSpans( + list_of_spans=spans, index_distance=0 + ).merge_neighbor_spans_by_symbol_distance() assert len(out) == 1 assert isinstance(out[0], Span) assert out[0].start == 0 @@ -64,7 +85,7 @@ def test_handling_of_boxes(self): spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=0)), - Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) + Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -77,14 +98,14 @@ def test_handling_of_boxes(self): assert out[0].end == 20 assert out[1].start == 21 assert out[1].end == 150 - assert out[0].box == Box(l=0, t=0, w=3, h=3, page=0) + assert out[0].box.coordinates == Box(l=0, t=0, w=3, h=3, page=0).coordinates # unmerged spans from separate pages keep their original box - assert out[1].box == spans[-1].box + assert out[1].box.coordinates == spans[-1].box.coordinates spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=1)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=1)), - Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) + Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) @@ -96,15 +117,15 @@ def test_handling_of_boxes(self): assert out[0].end == 20 assert out[1].start == 100 assert out[1].end == 150 - assert out[0].box == Box(l=0, t=0, w=3, h=3, page=1) + assert out[0].box.coordinates == Box(l=0, t=0, w=3, h=3, page=1).coordinates # unmerged spans that were too far apart in symbol distance keep their original box - assert out[1].box == spans[-1].box + assert out[1].box.coordinates == spans[-1].box.coordinates spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20), Span(start=21, end=150), - Span(start=155, end=200) + Span(start=155, end=200), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -126,223 +147,721 @@ def test_handling_of_boxes(self): list_of_spans_to_merge = [ - Span(start=3944, end=3948, - box=Box(l=0.19238134915568578, t=0.22752901673615306, w=0.06941334053447479, h=0.029442207414270286, - page=4)), - Span(start=3949, end=3951, - box=Box(l=0.27220460878651254, t=0.22752901673615306, w=0.03468585042904468, h=0.029442207414270286, - page=4)), - Span(start=4060, end=4063, - box=Box(l=0.4204075769894973, t=0.34144142726484455, w=0.023417310961637895, h=0.014200429984914883, - page=4)), - Span(start=4072, end=4075, - box=Box(l=0.5182742633669088, t=0.34144142726484455, w=0.029000512031393755, h=0.014200429984914883, - page=4)), - Span(start=4076, end=4083, - box=Box(l=0.5522956396696659, t=0.34144142726484455, w=0.06440764687304719, h=0.014200429984914883, - page=4)), - Span(start=4119, end=4128, - box=Box(l=0.2686971421659869, t=0.36273518298114954, w=0.08479235581478171, h=0.014200429984914883, - page=4)), - Span(start=4134, end=4144, - box=Box(l=0.40387889180816966, t=0.36273518298114954, w=0.08368776567508182, h=0.014200429984914883, - page=4)), - Span(start=4145, end=4148, - box=Box(l=0.4943548659781345, t=0.36273518298114954, w=0.042396177907390975, h=0.014200429984914883, - page=4)), - Span(start=4149, end=4162, - box=Box(l=0.5435392523804085, t=0.36273518298114954, w=0.11491754144296094, h=0.014200429984914883, - page=4)), - Span(start=4166, end=4177, - box=Box(l=0.6876581404256177, t=0.36273518298114954, w=0.09146006356715199, h=0.014200429984914883, - page=4)), - Span(start=4419, end=4427, - box=Box(l=0.2686971421659869, t=0.4479113936500019, w=0.06846450520430858, h=0.014200429984914883, - page=4)), - Span(start=4497, end=4505, - box=Box(l=0.2686971421659869, t=0.46920514936630686, w=0.06846450520430858, h=0.014200429984914883, - page=4)), - Span(start=4517, end=4520, - box=Box(l=0.42195400318507725, t=0.46920514936630686, w=0.029000512031393755, h=0.014200429984914883, - page=4)), - Span(start=4574, end=4581, - box=Box(l=0.2686971421659869, t=0.49049890508261185, w=0.07810456460532592, h=0.014200429984914883, - page=4)), - Span(start=4582, end=4587, - box=Box(l=0.35061756361754887, t=0.49049890508261185, w=0.03904224057412029, h=0.014200429984914883, - page=4)), - Span(start=4588, end=4591, - box=Box(l=0.39347566103790516, t=0.49049890508261185, w=0.023417310961637943, h=0.014200429984914883, - page=4)), - Span(start=4592, end=4601, - box=Box(l=0.4207088288457791, t=0.49049890508261185, w=0.08254300862121101, h=0.014200429984914883, - page=4)), - Span(start=4602, end=4613, - box=Box(l=0.5070676943132262, t=0.49049890508261185, w=0.09481400090042272, h=0.014200429984914883, - page=4)),] - -list_of_spans_to_merge_2 = [Span(start=30113, end=30119, - box=Box(l=0.12095229775767885, t=0.3578497466414853, w=0.05243790645011725, - h=0.014200429984914883, page=19)), - Span(start=30120, end=30124, - box=Box(l=0.17929474059091924, t=0.3578497466414853, w=0.030687522426571887, - h=0.014200429984914883, - page=19)), - Span(start=30125, end=30129, - box=Box(l=0.21799556239458678, t=0.3578497466414853, w=0.04350076804709073, - h=0.014200429984914883, page=19)), - Span(start=30130, end=30135, - box=Box(l=0.26740086682480063, t=0.3578497466414853, w=0.050208642713631964, - h=0.014200429984914883, - page=19)), - Span(start=30136, end=30141, - box=Box(l=0.32351404592155575, t=0.3578497466414853, w=0.0446254416438761, - h=0.014200429984914883, page=19)), - Span(start=30142, end=30151, - box=Box(l=0.37404402394855496, t=0.3578497466414853, w=0.0769598075514552, - h=0.014200429984914883, page=19)), - Span(start=30152, end=30155, - box=Box(l=0.4569284513402187, t=0.3578497466414853, w=0.029000512031393852, - h=0.014200429984914883, page=19)), - Span(start=30156, end=30165, - box=Box(l=0.4918334997547357, t=0.3578497466414853, w=0.0792091547450259, - h=0.014200429984914883, page=19)), - Span(start=30166, end=30175, - box=Box(l=0.5769471908828846, t=0.3578497466414853, w=0.07175819216632291, - h=0.014200429984914883, page=19)), - Span(start=30176, end=30179, - box=Box(l=0.6576023545380633, t=0.3578497466414853, w=0.03122977576787907, - h=0.014200429984914883, page=19)), - Span(start=30180, end=30184, - box=Box(l=0.6947366666890655, t=0.3578497466414853, w=0.03904224057412024, - h=0.014200429984914883, page=19)), - Span(start=30185, end=30190, - box=Box(l=0.7396834436463088, t=0.3578497466414853, w=0.05020864271363187, - h=0.014200429984914883, page=19)), - Span(start=30191, end=30193, - box=Box(l=0.7957966227430638, t=0.3578497466414853, w=0.015624929612482252, - h=0.014200429984914883, page=19)), - Span(start=30194, end=30197, - box=Box(l=0.12095229775767885, t=0.37500875791374183, w=0.024541984558423317, - h=0.014200429984914883, - page=19)), - Span(start=30198, end=30207, - box=Box(l=0.1518205712980198, t=0.37500875791374183, w=0.07695980755145514, - h=0.014200429984914883, page=19)), - Span(start=30208, end=30210, - box=Box(l=0.2351066678313926, t=0.37500875791374183, w=0.013395665875996984, - h=0.014200429984914883, - page=19)), - Span(start=30211, end=30214, - box=Box(l=0.2548286226893072, t=0.37500875791374183, w=0.02231272082193805, - h=0.014200429984914883, page=19)), - Span(start=30215, end=30217, - box=Box(l=0.283467632493163, t=0.37500875791374183, w=0.015624929612482252, - h=0.014200429984914883, page=19)), - Span(start=30218, end=30221, - box=Box(l=0.3054188510875629, t=0.37500875791374183, w=0.024541984558423317, - h=0.014200429984914883, - page=19)), - Span(start=30222, end=30229, - box=Box(l=0.33628712462790383, t=0.37500875791374183, w=0.055570925755447906, - h=0.014200429984914883, - page=19)), - Span(start=30230, end=30235, - box=Box(l=0.3981843393652693, t=0.37500875791374183, w=0.04183384110899822, - h=0.014200429984914883, page=19)), - Span(start=30236, end=30240, - box=Box(l=0.44668588822663785, t=0.37500875791374183, w=0.03570838669793504, - h=0.014200429984914883, - page=19)), - Span(start=30241, end=30244, - box=Box(l=0.4887205639064905, t=0.37500875791374183, w=0.020083457085452783, - h=0.014200429984914883, - page=19)), - Span(start=30245, end=30255, - box=Box(l=0.5151303099738609, t=0.37500875791374183, w=0.08810612623388145, - h=0.014200429984914883, page=19)), - Span(start=30256, end=30259, - box=Box(l=0.6095627251896601, t=0.37500875791374183, w=0.022312720821938, - h=0.014200429984914883, page=19)), - Span(start=30260, end=30262, - box=Box(l=0.6382017349935157, t=0.37500875791374183, w=0.015624929612482252, - h=0.014200429984914883, - page=19)), - Span(start=30263, end=30268, - box=Box(l=0.6601529535879158, t=0.37500875791374183, w=0.03958449391542752, - h=0.014200429984914883, page=19)), - Span(start=30269, end=30273, - box=Box(l=0.7098795933314969, t=0.37500875791374183, w=0.035708386697935225, - h=0.014200429984914883, - page=19)), - Span(start=30274, end=30276, - box=Box(l=0.7519142690113497, t=0.37500875791374183, w=0.013395665875997033, - h=0.014200429984914883, - page=19)), - Span(start=30277, end=30278, - box=Box(l=0.7716362238692644, t=0.37500875791374183, w=0.008917054945941066, - h=0.014200429984914883, - page=19)), - Span(start=30279, end=30281, - box=Box(l=0.7868795677971232, t=0.37500875791374183, w=0.02454198455842322, - h=0.014200429984914883, page=19)), - Span(start=30282, end=30291, - box=Box(l=0.12095229775767885, t=0.3921677691859983, w=0.08031374488472577, - h=0.014200429984914883, page=19)), - Span(start=30292, end=30296, - box=Box(l=0.2062869069137678, t=0.3921677691859983, w=0.03904224057412024, - h=0.014200429984914883, page=19)), - Span(start=30297, end=30302, - box=Box(l=0.25035001175925126, t=0.3921677691859983, w=0.050208642713631964, - h=0.014200429984914883, - page=19)), - Span(start=30303, end=30311, - box=Box(l=0.30557951874424644, t=0.3921677691859983, w=0.08143841848151108, - h=0.014200429984914883, page=19)), - Span(start=30312, end=30314, - box=Box(l=0.3920388014971207, t=0.3921677691859983, w=0.016729519752182193, - h=0.014200429984914883, page=19)), - Span(start=30315, end=30321, - box=Box(l=0.4137891855206661, t=0.3921677691859983, w=0.0535625800469026, - h=0.014200429984914883, page=19)), - Span(start=30322, end=30328, - box=Box(l=0.47237262983893197, t=0.3921677691859983, w=0.05354249658981717, - h=0.014200429984914883, page=19)), - Span(start=30329, end=30333, - box=Box(l=0.5309359907001122, t=0.3921677691859983, w=0.03681297683763493, - h=0.014200429984914883, page=19)), - Span(start=30334, end=30336, - box=Box(l=0.5727698318091105, t=0.3921677691859983, w=0.01672951975218224, - h=0.014200429984914883, page=19)), - Span(start=30337, end=30344, - box=Box(l=0.5945202158326559, t=0.3921677691859983, w=0.060230287799273016, - h=0.014200429984914883, page=19)), - Span(start=30345, end=30348, - box=Box(l=0.6597713679032922, t=0.3921677691859983, w=0.029000512031393946, - h=0.014200429984914883, page=19)), - Span(start=30349, end=30359, - box=Box(l=0.6937927442060494, t=0.3921677691859983, w=0.07834556609035141, - h=0.014200429984914883, page=19))] + Span( + start=3944, + end=3948, + box=Box( + l=0.19238134915568578, + t=0.22752901673615306, + w=0.06941334053447479, + h=0.029442207414270286, + page=4, + ), + ), + Span( + start=3949, + end=3951, + box=Box( + l=0.27220460878651254, + t=0.22752901673615306, + w=0.03468585042904468, + h=0.029442207414270286, + page=4, + ), + ), + Span( + start=4060, + end=4063, + box=Box( + l=0.4204075769894973, + t=0.34144142726484455, + w=0.023417310961637895, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4072, + end=4075, + box=Box( + l=0.5182742633669088, + t=0.34144142726484455, + w=0.029000512031393755, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4076, + end=4083, + box=Box( + l=0.5522956396696659, + t=0.34144142726484455, + w=0.06440764687304719, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4119, + end=4128, + box=Box( + l=0.2686971421659869, + t=0.36273518298114954, + w=0.08479235581478171, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4134, + end=4144, + box=Box( + l=0.40387889180816966, + t=0.36273518298114954, + w=0.08368776567508182, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4145, + end=4148, + box=Box( + l=0.4943548659781345, + t=0.36273518298114954, + w=0.042396177907390975, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4149, + end=4162, + box=Box( + l=0.5435392523804085, + t=0.36273518298114954, + w=0.11491754144296094, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4166, + end=4177, + box=Box( + l=0.6876581404256177, + t=0.36273518298114954, + w=0.09146006356715199, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4419, + end=4427, + box=Box( + l=0.2686971421659869, + t=0.4479113936500019, + w=0.06846450520430858, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4497, + end=4505, + box=Box( + l=0.2686971421659869, + t=0.46920514936630686, + w=0.06846450520430858, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4517, + end=4520, + box=Box( + l=0.42195400318507725, + t=0.46920514936630686, + w=0.029000512031393755, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4574, + end=4581, + box=Box( + l=0.2686971421659869, + t=0.49049890508261185, + w=0.07810456460532592, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4582, + end=4587, + box=Box( + l=0.35061756361754887, + t=0.49049890508261185, + w=0.03904224057412029, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4588, + end=4591, + box=Box( + l=0.39347566103790516, + t=0.49049890508261185, + w=0.023417310961637943, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4592, + end=4601, + box=Box( + l=0.4207088288457791, + t=0.49049890508261185, + w=0.08254300862121101, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4602, + end=4613, + box=Box( + l=0.5070676943132262, + t=0.49049890508261185, + w=0.09481400090042272, + h=0.014200429984914883, + page=4, + ), + ), +] + +list_of_spans_to_merge_2 = [ + Span( + start=30113, + end=30119, + box=Box( + l=0.12095229775767885, + t=0.3578497466414853, + w=0.05243790645011725, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30120, + end=30124, + box=Box( + l=0.17929474059091924, + t=0.3578497466414853, + w=0.030687522426571887, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30125, + end=30129, + box=Box( + l=0.21799556239458678, + t=0.3578497466414853, + w=0.04350076804709073, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30130, + end=30135, + box=Box( + l=0.26740086682480063, + t=0.3578497466414853, + w=0.050208642713631964, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30136, + end=30141, + box=Box( + l=0.32351404592155575, + t=0.3578497466414853, + w=0.0446254416438761, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30142, + end=30151, + box=Box( + l=0.37404402394855496, + t=0.3578497466414853, + w=0.0769598075514552, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30152, + end=30155, + box=Box( + l=0.4569284513402187, + t=0.3578497466414853, + w=0.029000512031393852, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30156, + end=30165, + box=Box( + l=0.4918334997547357, + t=0.3578497466414853, + w=0.0792091547450259, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30166, + end=30175, + box=Box( + l=0.5769471908828846, + t=0.3578497466414853, + w=0.07175819216632291, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30176, + end=30179, + box=Box( + l=0.6576023545380633, + t=0.3578497466414853, + w=0.03122977576787907, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30180, + end=30184, + box=Box( + l=0.6947366666890655, + t=0.3578497466414853, + w=0.03904224057412024, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30185, + end=30190, + box=Box( + l=0.7396834436463088, + t=0.3578497466414853, + w=0.05020864271363187, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30191, + end=30193, + box=Box( + l=0.7957966227430638, + t=0.3578497466414853, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30194, + end=30197, + box=Box( + l=0.12095229775767885, + t=0.37500875791374183, + w=0.024541984558423317, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30198, + end=30207, + box=Box( + l=0.1518205712980198, + t=0.37500875791374183, + w=0.07695980755145514, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30208, + end=30210, + box=Box( + l=0.2351066678313926, + t=0.37500875791374183, + w=0.013395665875996984, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30211, + end=30214, + box=Box( + l=0.2548286226893072, + t=0.37500875791374183, + w=0.02231272082193805, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30215, + end=30217, + box=Box( + l=0.283467632493163, + t=0.37500875791374183, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30218, + end=30221, + box=Box( + l=0.3054188510875629, + t=0.37500875791374183, + w=0.024541984558423317, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30222, + end=30229, + box=Box( + l=0.33628712462790383, + t=0.37500875791374183, + w=0.055570925755447906, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30230, + end=30235, + box=Box( + l=0.3981843393652693, + t=0.37500875791374183, + w=0.04183384110899822, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30236, + end=30240, + box=Box( + l=0.44668588822663785, + t=0.37500875791374183, + w=0.03570838669793504, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30241, + end=30244, + box=Box( + l=0.4887205639064905, + t=0.37500875791374183, + w=0.020083457085452783, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30245, + end=30255, + box=Box( + l=0.5151303099738609, + t=0.37500875791374183, + w=0.08810612623388145, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30256, + end=30259, + box=Box( + l=0.6095627251896601, + t=0.37500875791374183, + w=0.022312720821938, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30260, + end=30262, + box=Box( + l=0.6382017349935157, + t=0.37500875791374183, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30263, + end=30268, + box=Box( + l=0.6601529535879158, + t=0.37500875791374183, + w=0.03958449391542752, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30269, + end=30273, + box=Box( + l=0.7098795933314969, + t=0.37500875791374183, + w=0.035708386697935225, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30274, + end=30276, + box=Box( + l=0.7519142690113497, + t=0.37500875791374183, + w=0.013395665875997033, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30277, + end=30278, + box=Box( + l=0.7716362238692644, + t=0.37500875791374183, + w=0.008917054945941066, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30279, + end=30281, + box=Box( + l=0.7868795677971232, + t=0.37500875791374183, + w=0.02454198455842322, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30282, + end=30291, + box=Box( + l=0.12095229775767885, + t=0.3921677691859983, + w=0.08031374488472577, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30292, + end=30296, + box=Box( + l=0.2062869069137678, + t=0.3921677691859983, + w=0.03904224057412024, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30297, + end=30302, + box=Box( + l=0.25035001175925126, + t=0.3921677691859983, + w=0.050208642713631964, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30303, + end=30311, + box=Box( + l=0.30557951874424644, + t=0.3921677691859983, + w=0.08143841848151108, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30312, + end=30314, + box=Box( + l=0.3920388014971207, + t=0.3921677691859983, + w=0.016729519752182193, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30315, + end=30321, + box=Box( + l=0.4137891855206661, + t=0.3921677691859983, + w=0.0535625800469026, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30322, + end=30328, + box=Box( + l=0.47237262983893197, + t=0.3921677691859983, + w=0.05354249658981717, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30329, + end=30333, + box=Box( + l=0.5309359907001122, + t=0.3921677691859983, + w=0.03681297683763493, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30334, + end=30336, + box=Box( + l=0.5727698318091105, + t=0.3921677691859983, + w=0.01672951975218224, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30337, + end=30344, + box=Box( + l=0.5945202158326559, + t=0.3921677691859983, + w=0.060230287799273016, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30345, + end=30348, + box=Box( + l=0.6597713679032922, + t=0.3921677691859983, + w=0.029000512031393946, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30349, + end=30359, + box=Box( + l=0.6937927442060494, + t=0.3921677691859983, + w=0.07834556609035141, + h=0.014200429984914883, + page=19, + ), + ), +] def test_merge_spans(): - assert len(list_of_spans_to_merge) == (len(MergeSpans(list_of_spans_to_merge, 0, 0) - .merge_neighbor_spans_by_box_coordinate())) + assert len(list_of_spans_to_merge) == ( + len( + MergeSpans( + list_of_spans_to_merge, 0, 0 + ).merge_neighbor_spans_by_box_coordinate() + ) + ) - assert 4 == len(MergeSpans(list_of_spans_to_merge, 0.04387334, 0.01421097).merge_neighbor_spans_by_box_coordinate()) + assert 4 == len( + MergeSpans( + list_of_spans_to_merge, 0.04387334, 0.01421097 + ).merge_neighbor_spans_by_box_coordinate() + ) merge_spans = MergeSpans(list_of_spans_to_merge_2, 0.04387334, 0.01421097) assert 1 == len(merge_spans.merge_neighbor_spans_by_box_coordinate()) - assert [30113, 30359] == [merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end] + assert [30113, 30359] == [ + merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, + merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end, + ] def test_merge_neighbor_spans_by_symbol_distance(): - assert 7 == (len(MergeSpans(list_of_spans_to_merge, index_distance=10) - .merge_neighbor_spans_by_symbol_distance())) - + assert 7 == ( + len( + MergeSpans( + list_of_spans_to_merge, index_distance=10 + ).merge_neighbor_spans_by_symbol_distance() + ) + ) - assert 10 == len(MergeSpans(list_of_spans_to_merge, index_distance=1).merge_neighbor_spans_by_symbol_distance()) + assert 10 == len( + MergeSpans( + list_of_spans_to_merge, index_distance=1 + ).merge_neighbor_spans_by_symbol_distance() + ) list_of_spans_to_merge_2 = [ Span(start=1, end=3, box=Box(l=0.1, t=0.2, w=0.2, h=0.2, page=11)), @@ -360,7 +879,9 @@ def test_merge_neighbor_spans_by_symbol_distance(): assert 1 == len(result) assert set([(1, 7)]) == set([(entry.start, entry.end) for entry in result]) - assert [Box(l=0.1, t=0.2, w=0.4, h=0.2, page=11)] == [entry.box for entry in result] + assert [Box(l=0.1, t=0.2, w=0.4, h=0.2, page=11).coordinates] == [ + entry.box.coordinates for entry in result + ] def test_from_span_groups_with_box_groups(): @@ -370,30 +891,41 @@ def test_from_span_groups_with_box_groups(): list_of_spans_to_merge_in_span_group_format.append( SpanGroup( spans=[Span(start=span.start, end=span.end)], - box_group=BoxGroup(boxes=[span.box]) + box_group=BoxGroup(boxes=[span.box]), ) ) - assert 7 == (len(MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, - index_distance=10).merge_neighbor_spans_by_symbol_distance()) - ) + assert 7 == ( + len( + MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, index_distance=10 + ).merge_neighbor_spans_by_symbol_distance() + ) + ) - assert len(list_of_spans_to_merge) == (len(MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, - 0, - 0).merge_neighbor_spans_by_box_coordinate())) + assert len(list_of_spans_to_merge) == ( + len( + MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, 0, 0 + ).merge_neighbor_spans_by_box_coordinate() + ) + ) def test_box_groups_to_span_groups(): # basic doc annotated with pages and tokens, from pdfplumber parser split at punctuation - with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", "r") as f: + with open( + fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) doc = Document.from_json(fixture_doc_json) # boxes drawn neatly around bib entries - with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r") as f: + with open( + fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r" + ) as f: raw_json = f.read() fixture_bib_entries_json = json.loads(raw_json)["bib_entries"] @@ -404,15 +936,28 @@ def test_box_groups_to_span_groups(): # generate span_groups with different settings overlap_span_groups = box_groups_to_span_groups(box_groups, doc, center=False) - overlap_at_token_center_span_groups = box_groups_to_span_groups(box_groups, doc, center=True) - overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups(box_groups, doc, center=True, pad_x=True) - - assert (len(box_groups) == len(overlap_span_groups) == len(overlap_at_token_center_span_groups) == len(overlap_at_token_center_span_groups_x_padded)) + overlap_at_token_center_span_groups = box_groups_to_span_groups( + box_groups, doc, center=True + ) + overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups( + box_groups, doc, center=True, pad_x=True + ) + + assert ( + len(box_groups) + == len(overlap_span_groups) + == len(overlap_at_token_center_span_groups) + == len(overlap_at_token_center_span_groups_x_padded) + ) # annotate all onto doc to extract texts: doc.annotate(overlap_span_groups=overlap_span_groups) - doc.annotate(overlap_at_token_center_span_groups=overlap_at_token_center_span_groups) - doc.annotate(overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded) + doc.annotate( + overlap_at_token_center_span_groups=overlap_at_token_center_span_groups + ) + doc.annotate( + overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded + ) # when center=False, any token overlap with BoxGroup becomes part of the SpanGroup # in this example, tokens from bib entry '29 and '31' overlap with the box drawn neatly around '30' @@ -444,4 +989,6 @@ def test_box_groups_to_span_groups(): assert doc.overlap_at_token_center_span_groups_x_padded[6].text.startswith("[6]") # original box_group boxes are saved - assert all([sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups]) + assert all( + [sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups] + ) From dc70d9aa370caf0a8230821dcb55ebfeeb707d4d Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 6 Jul 2023 11:23:07 -0700 Subject: [PATCH 23/25] fix bug in tools; should just use the librarys span merging function --- src/mmda/utils/tools.py | 87 ++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 35 deletions(-) diff --git a/src/mmda/utils/tools.py b/src/mmda/utils/tools.py index 9effac49..96394113 100644 --- a/src/mmda/utils/tools.py +++ b/src/mmda/utils/tools.py @@ -1,10 +1,10 @@ from __future__ import annotations +import itertools import logging from collections import defaultdict from itertools import groupby -import itertools -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple import numpy as np @@ -14,7 +14,12 @@ def allocate_overlapping_tokens_for_box( - tokens: List[SpanGroup], box, token_box_in_box_group: bool = False, x: float = 0.0, y: float = 0.0, center: bool = False + tokens: List[SpanGroup], + box, + token_box_in_box_group: bool = False, + x: float = 0.0, + y: float = 0.0, + center: bool = False, ) -> Tuple[List[Span], List[Span]]: """Finds overlap of tokens for given box Args @@ -29,10 +34,14 @@ def allocate_overlapping_tokens_for_box( """ allocated_tokens, remaining_tokens = [], [] for token in tokens: - if token_box_in_box_group and token.box_group.boxes[0].is_overlap(other=box, x=x, y=y, center=center): + if token_box_in_box_group and token.box_group.boxes[0].is_overlap( + other=box, x=x, y=y, center=center + ): # The token "box" is stored within the SpanGroup's .box_group allocated_tokens.append(token) - elif token.spans[0].box is not None and token.spans[0].box.is_overlap(other=box, x=x, y=y, center=center): + elif token.spans[0].box is not None and token.spans[0].box.is_overlap( + other=box, x=x, y=y, center=center + ): # default to assuming the token "box" is stored in the SpanGroup .box allocated_tokens.append(token) else: @@ -41,7 +50,7 @@ def allocate_overlapping_tokens_for_box( def box_groups_to_span_groups( - box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False + box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False ) -> List[SpanGroup]: """Generate SpanGroups from BoxGroups. Args @@ -60,23 +69,22 @@ def box_groups_to_span_groups( token_box_in_box_group = None for box_id, box_group in enumerate(box_groups): - all_tokens_overlapping_box_group = [] for box in box_group.boxes: - # Caching the page tokens to avoid duplicated search if box.page not in all_page_tokens: - cur_page_tokens = all_page_tokens[box.page] = doc.pages[ - box.page - ].tokens + cur_page_tokens = all_page_tokens[box.page] = doc.pages[box.page].tokens if token_box_in_box_group is None: # Determine whether box is stored on token SpanGroup span.box or in the box_group token_box_in_box_group = all( [ ( - (hasattr(token.box_group, "boxes") and len(token.box_group.boxes) == 1) - and token.spans[0].box is None + ( + hasattr(token.box_group, "boxes") + and len(token.box_group.boxes) == 1 + ) + and token.spans[0].box is None ) for token in cur_page_tokens ] @@ -84,9 +92,16 @@ def box_groups_to_span_groups( # Determine average width of tokens on this page if we are going to pad x if pad_x: if token_box_in_box_group and box.page not in avg_token_widths: - avg_token_widths[box.page] = np.average([t.box_group.boxes[0].w for t in cur_page_tokens]) - elif not token_box_in_box_group and box.page not in avg_token_widths: - avg_token_widths[box.page] = np.average([t.spans[0].box.w for t in cur_page_tokens]) + avg_token_widths[box.page] = np.average( + [t.box_group.boxes[0].w for t in cur_page_tokens] + ) + elif ( + not token_box_in_box_group + and box.page not in avg_token_widths + ): + avg_token_widths[box.page] = np.average( + [t.spans[0].box.w for t in cur_page_tokens] + ) else: cur_page_tokens = all_page_tokens[box.page] @@ -99,7 +114,7 @@ def box_groups_to_span_groups( # optionally pad x a small amount so that extra narrow token boxes (when split at punctuation) are not missed x=avg_token_widths.get(box.page, 0.0) * 0.5 if pad_x else 0.0, y=0.0, - center=center + center=center, ) all_page_tokens[box.page] = remaining_tokens @@ -113,7 +128,8 @@ def box_groups_to_span_groups( else MergeSpans( list_of_spans=list( itertools.chain.from_iterable( - span_group.spans for span_group in all_tokens_overlapping_box_group + span_group.spans + for span_group in all_tokens_overlapping_box_group ) ), index_distance=1, @@ -131,9 +147,11 @@ def box_groups_to_span_groups( ) if not token_box_in_box_group: - logging.warning("tokens with box stored in SpanGroup span.box will be deprecated (that is, " - "future Spans wont contain box). Ensure Document is annotated with tokens " - "having box stored in SpanGroup box_group.boxes") + logging.warning( + "tokens with box stored in SpanGroup span.box will be deprecated (that is, " + "future Spans wont contain box). Ensure Document is annotated with tokens " + "having box stored in SpanGroup box_group.boxes" + ) del all_page_tokens @@ -150,6 +168,7 @@ def box_groups_to_span_groups( # ) return derived_span_groups + class MergeSpans: """ Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans @@ -201,20 +220,24 @@ def build_graph_index_overlap(self): """ starts_matrix = np.full( (len(self.list_of_spans), len(self.list_of_spans)), - [span.start for span in self.list_of_spans] + [span.start for span in self.list_of_spans], ) ends_matrix = np.full( (len(self.list_of_spans), len(self.list_of_spans)), - [span.end for span in self.list_of_spans] + [span.end for span in self.list_of_spans], ) starts_minus_ends = np.abs(starts_matrix - ends_matrix.T) ends_minus_starts = np.abs(ends_matrix - starts_matrix.T) - are_neighboring_spans = np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance - neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) + are_neighboring_spans = ( + np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance + ) + neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) if len(neighboring_spans) > 0: - neighboring_spans_no_dupes = neighboring_spans[np.where(neighboring_spans[:,1] < neighboring_spans[:,0])] + neighboring_spans_no_dupes = neighboring_spans[ + np.where(neighboring_spans[:, 1] < neighboring_spans[:, 0]) + ] for j, i in neighboring_spans_no_dupes: span_i = self.list_of_spans[i] @@ -305,14 +328,8 @@ def build_merged_spans_from_connected_components(self, index): for span in page_spans: spans_by_page[pg].append(span) for page_spans in spans_by_page.values(): - merged_box = Box.small_boxes_to_big_box( - [span.box for span in page_spans] - ) - merged_spans.append( - Span( - start=min([span.start for span in page_spans]), - end=max([span.end for span in page_spans]), - box=merged_box, - ) + merged_span = Span.small_spans_to_big_span( + spans=page_spans, merge_boxes=True ) + merged_spans.append(merged_span) return merged_spans From 1069e522a000c3d50ca3c80fb0e1fad15b2b4f79 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 6 Jul 2023 11:25:50 -0700 Subject: [PATCH 24/25] reformat test_tools; bugfix need to check Box equality via .coordinates --- tests/test_utils/test_tools.py | 34 ++++++---------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/tests/test_utils/test_tools.py b/tests/test_utils/test_tools.py index 59e8f2a9..a35645a7 100644 --- a/tests/test_utils/test_tools.py +++ b/tests/test_utils/test_tools.py @@ -20,11 +20,7 @@ class TestMergeNeighborSpans(unittest.TestCase): def test_merge_multiple_neighbor_spans(self): - spans = [ - Span(start=0, end=10, box=Box(l=1.0, t=2.0, w=3.0, h=4.0, page=0)), - Span(start=11, end=20, box=Box(l=5.0, t=6.0, w=7.0, h=8.0, page=0)), - Span(start=21, end=30, box=Box(l=9.0, t=10.0, w=11.0, h=12.0, page=0)), - ] + spans = [Span(start=0, end=10), Span(start=11, end=20), Span(start=21, end=30)] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) out = merge_spans.merge_neighbor_spans_by_symbol_distance() assert len(out) == 1 @@ -33,33 +29,18 @@ def test_merge_multiple_neighbor_spans(self): assert out[0].end == 30 def test_different_index_distances(self): - spans = [ - Span(start=0, end=10, box=Box(l=1.0, t=2.0, w=3.0, h=4.0, page=0)), - Span(start=15, end=20, box=Box(l=5.0, t=6.0, w=7.0, h=8.0, page=0)), - ] + spans = [Span(start=0, end=10), Span(start=15, end=20)] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) out = merge_spans.merge_neighbor_spans_by_symbol_distance() - # no merge happened - for s1, s2 in zip(spans, out): - assert s1.start == s2.start - assert s1.end == s2.end - assert s1.box.coordinates == s2.box.coordinates + assert out == spans # no merge happened merge_spans = MergeSpans(list_of_spans=spans, index_distance=2) out = merge_spans.merge_neighbor_spans_by_symbol_distance() - # no merge happened - for s1, s2 in zip(spans, out): - assert s1.start == s2.start - assert s1.end == s2.end - assert s1.box.coordinates == s2.box.coordinates + assert out == spans # no merge happened merge_spans = MergeSpans(list_of_spans=spans, index_distance=4) out = merge_spans.merge_neighbor_spans_by_symbol_distance() - # no merge happened - for s1, s2 in zip(spans, out): - assert s1.start == s2.start - assert s1.end == s2.end - assert s1.box.coordinates == s2.box.coordinates + assert out == spans # no merge happened merge_spans = MergeSpans(list_of_spans=spans, index_distance=5) out = merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -69,10 +50,7 @@ def test_different_index_distances(self): assert out[0].end == 20 def test_zero_index_distance(self): - spans = [ - Span(start=0, end=10, box=Box(l=1.0, t=2.0, w=3.0, h=4.0, page=0)), - Span(start=10, end=20, box=Box(l=5.0, t=6.0, w=7.0, h=8.0, page=0)), - ] + spans = [Span(start=0, end=10), Span(start=10, end=20)] out = MergeSpans( list_of_spans=spans, index_distance=0 ).merge_neighbor_spans_by_symbol_distance() From 1d546cf787951be712673bdab5e2dda0c4578dc9 Mon Sep 17 00:00:00 2001 From: kyleclo Date: Thu, 6 Jul 2023 11:39:30 -0700 Subject: [PATCH 25/25] minor fix in test; equivalence of Span and Box --- tests/test_internal_ai2/test_api.py | 99 +++++++++++----------- tests/test_recipes/core_recipe_fixtures.py | 35 -------- 2 files changed, 51 insertions(+), 83 deletions(-) diff --git a/tests/test_internal_ai2/test_api.py b/tests/test_internal_ai2/test_api.py index 404f2940..37c2c9cb 100644 --- a/tests/test_internal_ai2/test_api.py +++ b/tests/test_internal_ai2/test_api.py @@ -20,104 +20,107 @@ class ClassificationSpanGroup(mmda_api.SpanGroup): class TestApi(unittest.TestCase): def test_vanilla_span_group(self) -> None: - sg_ann = mmda_ann.SpanGroup.from_json({ - 'spans': [{'start': 0, 'end': 1}], - 'id': 1, - 'metadata': {'text': 'hello', 'id': 999} # note id not used; it's just in metadata - }) + sg_ann = mmda_ann.SpanGroup.from_json( + { + "spans": [{"start": 0, "end": 1}], + "id": 1, + "metadata": { + "text": "hello", + "id": 999, + }, # note id not used; it's just in metadata + } + ) sg_api = mmda_api.SpanGroup.from_mmda(sg_ann) - self.assertEqual(sg_api.text, 'hello') + self.assertEqual(sg_api.text, "hello") self.assertEqual(sg_api.id, 1) self.assertEqual(sg_api.attributes.dict(), {}) def test_classification_span_group(self) -> None: - sg_ann = mmda_ann.SpanGroup.from_json({ - 'spans': [{'start': 0, 'end': 1}], - 'metadata': {'text': 'hello', 'id': 1} - }) + sg_ann = mmda_ann.SpanGroup.from_json( + {"spans": [{"start": 0, "end": 1}], "metadata": {"text": "hello", "id": 1}} + ) with self.assertRaises(ValidationError): # this should fail because metadata is missing label # and confidence ClassificationSpanGroup.from_mmda(sg_ann) - sg_ann.metadata.label = 'label' + sg_ann.metadata.label = "label" sg_ann.metadata.score = 0.5 sg_api = ClassificationSpanGroup.from_mmda(sg_ann) - self.assertEqual( - sg_api.attributes.dict(), {'label': 'label', 'score': 0.5} - ) + self.assertEqual(sg_api.attributes.dict(), {"label": "label", "score": 0.5}) # extra field should just get ignored - sg_ann.metadata.extra = 'extra' - self.assertEqual( - sg_api.attributes.dict(), {'label': 'label', 'score': 0.5} - ) + sg_ann.metadata.extra = "extra" + self.assertEqual(sg_api.attributes.dict(), {"label": "label", "score": 0.5}) with self.assertRaises(ValidationError): # this should fail bc score is not a float - sg_ann.metadata.score = 'not a float' + sg_ann.metadata.score = "not a float" ClassificationSpanGroup.from_mmda(sg_ann) def test_equivalence(self): - sg_ann = mmda_ann.SpanGroup.from_json({ - 'spans': [{'start': 0, 'end': 1}], - 'metadata': {'label': 'label', 'score': 0.5} - }) + sg_ann = mmda_ann.SpanGroup.from_json( + { + "spans": [{"start": 0, "end": 1}], + "metadata": {"label": "label", "score": 0.5}, + } + ) sg_ann_2 = ClassificationSpanGroup.from_mmda(sg_ann).to_mmda() self.assertDictEqual(sg_ann.to_json(), sg_ann_2.to_json()) self.assertDictEqual(sg_ann.__dict__, sg_ann_2.__dict__) - def test_box(self): box = mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0) - assert box.to_mmda() == mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0) + assert ( + box.to_mmda().coordinates + == mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0).coordinates + ) assert mmda_api.Box.from_mmda(box.to_mmda()) == box def test_span(self): - span = mmda_api.Span(start=0, end=1, box=mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0)) - assert span.to_mmda() == mmdaSpan(start=0, end=1, box=mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0)) + span = mmda_api.Span( + start=0, + end=1, + box=mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0), + ) + assert ( + span.to_mmda().to_json() + == mmdaSpan( + start=0, end=1, box=mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0) + ).to_json() + ) def test_box_group(self): box_group = mmda_api.BoxGroup( - boxes=[ - mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0) - ], + boxes=[mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0)], id=0, - type='test', + type="test", # these attributes are going to be discarded because # BoxGroup is using the default Attributes class - attributes={'one': 'Test string'} + attributes={"one": "Test string"}, ) - self.assertEqual( - mmda_api.BoxGroup.from_mmda(box_group.to_mmda()), - box_group - ) + self.assertEqual(mmda_api.BoxGroup.from_mmda(box_group.to_mmda()), box_group) def test_span_group(self): box_group = mmda_api.BoxGroup( - boxes=[ - mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0) - ], + boxes=[mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0)], id=0, - type='test', - attributes={'one': 'Test string'} + type="test", + attributes={"one": "Test string"}, ) span_group = mmda_api.SpanGroup( spans=[], box_group=box_group, - attributes={'one': 'Test string'}, + attributes={"one": "Test string"}, id=0, - type='test', - text='this is a test' + type="test", + text="this is a test", ) - self.assertEqual( - mmda_api.SpanGroup.from_mmda(span_group.to_mmda()), - span_group - ) + self.assertEqual(mmda_api.SpanGroup.from_mmda(span_group.to_mmda()), span_group) diff --git a/tests/test_recipes/core_recipe_fixtures.py b/tests/test_recipes/core_recipe_fixtures.py index 0f676847..fd0015fe 100644 --- a/tests/test_recipes/core_recipe_fixtures.py +++ b/tests/test_recipes/core_recipe_fixtures.py @@ -439,13 +439,6 @@ { "start": 3370, "end": 3372, - "box": { - "left": 0.7474297722689077, - "top": 0.7828144700712589, - "width": 0.017234544537815144, - "height": 0.012956175771971501, - "page": 0, - }, } ], "id": 895, @@ -456,13 +449,6 @@ { "start": 3373, "end": 3382, - "box": { - "left": 0.7730432389915969, - "top": 0.7828144700712589, - "width": 0.0749152648739495, - "height": 0.012956175771971501, - "page": 0, - }, } ], "id": 896, @@ -473,13 +459,6 @@ { "start": 3383, "end": 3394, - "box": { - "left": 0.5165052100840336, - "top": 0.7828144700712589, - "width": 0.366509090756303, - "height": 0.029060926365795714, - "page": 0, - }, } ], "id": 897, @@ -490,13 +469,6 @@ { "start": 3395, "end": 3405, - "box": { - "left": 0.5679338243697479, - "top": 0.798919220665083, - "width": 0.07995728588235285, - "height": 0.012956175771971612, - "page": 0, - }, } ], "id": 898, @@ -507,13 +479,6 @@ { "start": 3406, "end": 3408, - "box": { - "left": 0.6542532240336135, - "top": 0.798919220665083, - "width": 0.018242948739495835, - "height": 0.012956175771971612, - "page": 0, - }, } ], "id": 899,