diff --git a/src/mmda/eval/metrics.py b/src/mmda/eval/metrics.py index abc431f9..dfdd36ce 100644 --- a/src/mmda/eval/metrics.py +++ b/src/mmda/eval/metrics.py @@ -75,11 +75,11 @@ def levenshtein( def box_overlap(box: Box, container: Box) -> float: """Returns the percentage of area of a box inside of a container.""" - bl, bt, bw, bh = box.xywh + bl, bt, bw, bh = box.l, box.t, box.w, box.h br = bl + bw bb = bt + bh - cl, ct, cw, ch = container.xywh + cl, ct, cw, ch = container.l, container.t, container.w, container.h cr = cl + cw cb = ct + ch diff --git a/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py b/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py index 13bdb2f6..bc926489 100644 --- a/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py +++ b/src/mmda/predictors/sklearn_predictors/svm_word_predictor.py @@ -451,7 +451,10 @@ def _create_words( else: spans = [ Span.small_spans_to_big_span( - spans=[span for token in tokens_in_word for span in token.spans] + spans=[ + span for token in tokens_in_word for span in token.spans + ], + merge_boxes=False, ) ] metadata = ( @@ -471,7 +474,8 @@ def _create_words( # last bit spans = [ Span.small_spans_to_big_span( - spans=[span for token in tokens_in_word for span in token.spans] + spans=[span for token in tokens_in_word for span in token.spans], + merge_boxes=False, ) ] metadata = ( diff --git a/src/mmda/types/annotation.py b/src/mmda/types/annotation.py index 4857df5c..240d02c0 100644 --- a/src/mmda/types/annotation.py +++ b/src/mmda/types/annotation.py @@ -22,7 +22,6 @@ __all__ = ["Annotation", "BoxGroup", "SpanGroup", "Relation"] - def warn_deepcopy_of_annotation(obj: "Annotation") -> None: """Warns when a deepcopy is performed on an Annotation.""" @@ -34,15 +33,14 @@ def warn_deepcopy_of_annotation(obj: "Annotation") -> None: warnings.warn(msg, UserWarning, stacklevel=2) - class Annotation: """Annotation is intended for storing model predictions for a document.""" def __init__( - self, - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None + self, + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, ): self.id = id self.doc = doc @@ -77,23 +75,30 @@ def __getattr__(self, field: str) -> List["Annotation"]: return self.__getattribute__(field) - class BoxGroup(Annotation): def __init__( - self, - boxes: List[Box], - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None, + self, + boxes: List[Box], + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, + allow_overlap: Optional[bool] = False, ): self.boxes = boxes + if not allow_overlap: + clusters = Box.cluster_boxes(boxes=boxes) + if any([len(cluster) > 1 for cluster in clusters]): + raise ValueError( + "BoxGroup does not allow overlapping boxes. " + "Consider setting allow_overlap=True." 
+ ) super().__init__(id=id, doc=doc, metadata=metadata) def to_json(self) -> Dict: box_group_dict = dict( boxes=[box.to_json() for box in self.boxes], id=self.id, - metadata=self.metadata.to_json() + metadata=self.metadata.to_json(), ) return { key: value for key, value in box_group_dict.items() if value @@ -101,16 +106,13 @@ def to_json(self) -> Dict: @classmethod def from_json(cls, box_group_dict: Dict) -> "BoxGroup": - if "metadata" in box_group_dict: metadata_dict = box_group_dict["metadata"] else: # this fallback is necessary to ensure compatibility with box # groups that were create before the metadata migration and # therefore have "type" in the root of the json dict instead. - metadata_dict = { - "type": box_group_dict.get("type", None) - } + metadata_dict = {"type": box_group_dict.get("type", None)} return cls( boxes=[ @@ -132,7 +134,7 @@ def __deepcopy__(self, memo): box_group = BoxGroup( boxes=deepcopy(self.boxes, memo), id=self.id, - metadata=deepcopy(self.metadata, memo) + metadata=deepcopy(self.metadata, memo), ) # Don't copy an attached document @@ -150,25 +152,30 @@ def type(self, type: Union[str, None]) -> None: class SpanGroup(Annotation): - def __init__( - self, - spans: List[Span], - box_group: Optional[BoxGroup] = None, - id: Optional[int] = None, - doc: Optional['Document'] = None, - metadata: Optional[Metadata] = None, + self, + spans: List[Span], + box_group: Optional[BoxGroup] = None, + id: Optional[int] = None, + doc: Optional["Document"] = None, + metadata: Optional[Metadata] = None, + allow_overlap: Optional[bool] = False, ): self.spans = spans + if not allow_overlap: + clusters = Span.cluster_spans(spans=spans) + if any([len(cluster) > 1 for cluster in clusters]): + raise ValueError( + "SpanGroup does not allow overlapping spans. " + "Consider setting allow_overlap=True." + ) self.box_group = box_group super().__init__(id=id, doc=doc, metadata=metadata) @property def symbols(self) -> List[str]: if self.doc is not None: - return [ - self.doc.symbols[span.start: span.end] for span in self.spans - ] + return [self.doc.symbols[span.start : span.end] for span in self.spans] else: return [] @@ -187,12 +194,10 @@ def to_json(self) -> Dict: spans=[span.to_json() for span in self.spans], id=self.id, metadata=self.metadata.to_json(), - box_group=self.box_group.to_json() if self.box_group else None + box_group=self.box_group.to_json() if self.box_group else None, ) return { - key: value - for key, value in span_group_dict.items() - if value is not None + key: value for key, value in span_group_dict.items() if value is not None } # only serialize non-null values @classmethod @@ -211,7 +216,7 @@ def from_json(cls, span_group_dict: Dict) -> "SpanGroup": # therefore have "id", "type" in the root of the json dict instead. 
metadata_dict = { "type": span_group_dict.get("type", None), - "text": span_group_dict.get("text", None) + "text": span_group_dict.get("text", None), } return cls( @@ -256,7 +261,7 @@ def __deepcopy__(self, memo): spans=deepcopy(self.spans, memo), id=self.id, metadata=deepcopy(self.metadata, memo), - box_group=deepcopy(self.box_group, memo) + box_group=deepcopy(self.box_group, memo), ) # Don't copy an attached document @@ -284,6 +289,5 @@ def text(self, text: Union[str, None]) -> None: self.metadata.text = text - class Relation(Annotation): - pass \ No newline at end of file + pass diff --git a/src/mmda/types/box.py b/src/mmda/types/box.py index c1bfd4a9..c7b96402 100644 --- a/src/mmda/types/box.py +++ b/src/mmda/types/box.py @@ -1,39 +1,64 @@ """ +A Box on a page. Can be in relative or absolute coordinates. +@kylel """ -from typing import List, Dict, Tuple, Union -from dataclasses import dataclass +import logging import warnings +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + import numpy as np -def is_overlap_1d(start1: float, end1: float, start2: float, end2: float, x: float = 0) -> bool: +def is_overlap_1d( + start1: float, end1: float, start2: float, end2: float, x: float = 0 +) -> bool: """Return whether two 1D intervals overlaps given x""" assert start1 <= end1 assert start2 <= end2 - return not (start1 - x > end2 or start1 > end2 + x or end1 + x < start2 or end1 < start2 - x) # ll # rr + return not ( + start1 - x > end2 or start1 > end2 + x or end1 + x < start2 or end1 < start2 - x + ) # ll # rr -@dataclass class Box: - l: float - t: float - w: float - h: float - page: int + def __init__(self, l: float, t: float, w: float, h: float, page: int) -> None: + if w < 0 or h < 0: + raise ValueError(f"Width and height must be non-negative, got {w} and {h}") + if page < 0: + raise ValueError(f"Page must be non-negative, got {page}") + if l < 0 or t < 0: + raise ValueError(f"Left and top must be non-negative, got {l} and {t}") + self.l = l + self.t = t + self.w = w + self.h = h + self.page = page def to_json(self) -> Dict[str, float]: - return {'left': self.l, 'top': self.t, 'width': self.w, 'height': self.h, 'page': self.page} + return { + "left": self.l, + "top": self.t, + "width": self.w, + "height": self.h, + "page": self.page, + } @classmethod def from_json(cls, box_dict: Dict[str, Union[float, int]]) -> "Box": - return Box(l=box_dict['left'], t=box_dict['top'], w=box_dict['width'], h=box_dict['height'], - page=box_dict['page']) + return Box( + l=box_dict["left"], + t=box_dict["top"], + w=box_dict["width"], + h=box_dict["height"], + page=box_dict["page"], + ) @classmethod def from_coordinates(cls, x1: float, y1: float, x2: float, y2: float, page: int): @@ -74,10 +99,6 @@ def from_pdf_coordinates( @classmethod def small_boxes_to_big_box(cls, boxes: List["Box"]) -> "Box": """Computes one big box that tightly encapsulates all smaller input boxes""" - boxes = [box for box in boxes if box is not None] - if not boxes: - return None - if len({box.page for box in boxes}) != 1: raise ValueError(f"Bboxes not all on same page: {boxes}") x1 = min([bbox.l for bbox in boxes]) @@ -95,11 +116,6 @@ def coordinates(self) -> Tuple[float, float, float, float]: def center(self) -> Tuple[float, float]: return self.l + self.w / 2, self.t + self.h / 2 - @property - def xywh(self) -> Tuple[float, float, float, float]: - """Return a tuple of the (left, top, width, height) format.""" - return self.l, self.t, self.w, self.h - def get_relative(self, page_width: float, page_height: 
float) -> "Box": """Get the relative coordinates of self based on page_width, page_height.""" return self.__class__( @@ -120,18 +136,109 @@ def get_absolute(self, page_width: int, page_height: int) -> "Box": page=self.page, ) - def is_overlap(self, other: "Box", x: float = 0.0, y: float = 0, center: bool = False) -> bool: + def is_overlap( + self, other: "Box", x: float = 0.0, y: float = 0, center: bool = False + ) -> bool: """ Whether self overlaps with the other Box object. x, y distances for padding center (bool) if True, only consider overlapping if this box's center is contained by other """ + if self.page != other.page: + return False + x11, y11, x12, y12 = self.coordinates x21, y21, x22, y22 = other.coordinates if center: center_x, center_y = self.center - res = is_overlap_1d(center_x, center_x, x21, x22, x) and is_overlap_1d(center_y, center_y, y21, y22, y) + res = is_overlap_1d(center_x, center_x, x21, x22, x) and is_overlap_1d( + center_y, center_y, y21, y22, y + ) else: - res = is_overlap_1d(x11, x12, x21, x22, x) and is_overlap_1d(y11, y12, y21, y22, y) + res = is_overlap_1d(x11, x12, x21, x22, x) and is_overlap_1d( + y11, y12, y21, y22, y + ) return res + + @classmethod + def cluster_boxes(cls, boxes: List["Box"]) -> List[List[int]]: + """ + Cluster boxes into groups based on any overlap. + """ + if not boxes: + return [] + + clusters: List[List[int]] = [[0]] + cluster_id_to_big_box: Dict[int, Box] = {0: boxes[0]} + for box_id in range(1, len(boxes)): + box = boxes[box_id] + + # check all the clusters to see if the box overlaps with any of them + is_overlap = False + for cluster_id, big_box in cluster_id_to_big_box.items(): + if box.is_overlap(big_box, x=0, y=0): + is_overlap = True + break + + # resolve + if is_overlap: + clusters[cluster_id].append(box_id) + cluster_id_to_big_box[cluster_id] = cls.small_boxes_to_big_box( + [box, big_box] + ) + else: + clusters.append([box_id]) + cluster_id_to_big_box[len(clusters) - 1] = box + + # sort clusters + for cluster in clusters: + cluster.sort() + clusters.sort(key=lambda x: x[0]) + + return clusters + + def shrink(self, delta: float, ignore: bool = True, clip: bool = True): + x1, y1, x2, y2 = self.coordinates + if x2 - x1 <= 2 * delta: + if ignore: + logging.warning(f"box's x-coords {self} shrink too much. Ignoring.") + else: + raise ValueError( + f"box's x-coords {self} shrink too much with delta={delta}." + ) + else: + if clip: + logging.warning( + f"box's x-coords {self} go beyond page boundary. Clipping..." + ) + x1 = min(x1 + delta, 1.0) + x2 = max(x2 - delta, 0.0) + else: + raise ValueError( + f"box's x-coordinates {self} go beyond page boundary. need clip." + ) + + if y2 - y1 <= 2 * delta: + if ignore: + logging.warning(f"box's y-coords {self} shrink too much. Ignoring.") + else: + raise ValueError( + f"box's y-coords {self} shrink too much with delta={delta}." + ) + else: + if clip: + logging.warning( + f"box's y-coords {self} go beyond page boundary. Clipping..." + ) + y1 = min(y1 + delta, 1.0) + y2 = max(y2 - delta, 0.0) + else: + raise ValueError( + f"box's y-coordinates {self} go beyond page boundary. need clip." 
+ ) + + self.l = x1 + self.t = y1 + self.w = x2 - x1 + self.h = y2 - y1 diff --git a/src/mmda/types/indexers.py b/src/mmda/types/indexers.py index beb12b1f..828d7094 100644 --- a/src/mmda/types/indexers.py +++ b/src/mmda/types/indexers.py @@ -4,15 +4,16 @@ """ -from typing import List - from abc import abstractmethod +from collections import defaultdict from dataclasses import dataclass, field +from typing import List -from mmda.types.annotation import SpanGroup, Annotation -from ncls import NCLS import numpy as np import pandas as pd +from ncls import NCLS + +from mmda.types.annotation import Annotation, Box, BoxGroup, SpanGroup @dataclass @@ -56,7 +57,7 @@ def __init__(self, span_groups: List[SpanGroup]) -> None: self._index = NCLS( pd.Series(starts, dtype=np.int64), pd.Series(ends, dtype=np.int64), - pd.Series(ids, dtype=np.int64) + pd.Series(ids, dtype=np.int64), ) self._ensure_disjoint() @@ -68,15 +69,26 @@ def _ensure_disjoint(self) -> None: """ for span_group in self._sgs: for span in span_group.spans: - matches = [match for match in self._index.find_overlap(span.start, span.end)] - if len(matches) > 1: + match_ids = [ + matched_id + for _start, _end, matched_id in self._index.find_overlap( + span.start, span.end + ) + ] + if len(match_ids) > 1: + matches = [self._sgs[match_id].to_json() for match_id in match_ids] raise ValueError( - f"Detected overlap with existing SpanGroup(s) {matches} for {span_group}" + f"Detected overlap! While processing the Span {span} as part of query SpanGroup {span_group.to_json()}, we found that it overlaps with existing SpanGroup(s):\n" + + "\n".join( + [f"\t{i}\t{m} " for i, m in zip(match_ids, matches)] + ) ) def find(self, query: SpanGroup) -> List[SpanGroup]: if not isinstance(query, SpanGroup): - raise ValueError(f'SpanGroupIndexer only works with `query` that is SpanGroup type') + raise ValueError( + f"SpanGroupIndexer only works with `query` that is SpanGroup type" + ) if not query.spans: return [] @@ -84,7 +96,9 @@ def find(self, query: SpanGroup) -> List[SpanGroup]: matched_ids = set() for span in query.spans: - for _start, _end, matched_id in self._index.find_overlap(span.start, span.end): + for _start, _end, matched_id in self._index.find_overlap( + span.start, span.end + ): matched_ids.add(matched_id) matched_span_groups = [self._sgs[matched_id] for matched_id in matched_ids] @@ -95,3 +109,78 @@ def find(self, query: SpanGroup) -> List[SpanGroup]: return sorted(list(matched_span_groups)) +class BoxGroupIndexer(Indexer): + """ + Manages a data structure for locating overlapping BoxGroups. + Builds a static nested containment list from BoxGroups + and accepts other BoxGroups as search probes. 
+ """ + + def __init__(self, box_groups: List[BoxGroup]) -> None: + self._bgs = box_groups + + self._box_id_to_box_group_id = {} + self._boxes = [] + box_id = 0 + for bg_id, bg in enumerate(box_groups): + for box in bg.boxes: + self._boxes.append(box) + self._box_id_to_box_group_id[box_id] = bg_id + box_id += 1 + + self._np_boxes_x1 = np.array([b.l for b in self._boxes]) + self._np_boxes_y1 = np.array([b.t for b in self._boxes]) + self._np_boxes_x2 = np.array([b.l + b.w for b in self._boxes]) + self._np_boxes_y2 = np.array([b.t + b.h for b in self._boxes]) + self._np_boxes_page = np.array([b.page for b in self._boxes]) + + self._ensure_disjoint() + + def _find_overlap_boxes(self, query: Box) -> List[int]: + x1, y1, x2, y2 = query.coordinates + mask = ( + (self._np_boxes_x1 <= x2) + & (self._np_boxes_x2 >= x1) + & (self._np_boxes_y1 <= y2) + & (self._np_boxes_y2 >= y1) + & (self._np_boxes_page == query.page) + ) + return np.where(mask)[0].tolist() + + def _find_overlap_box_groups(self, query: Box) -> List[int]: + return [ + self._box_id_to_box_group_id[box_id] + for box_id in self._find_overlap_boxes(query) + ] + + def _ensure_disjoint(self) -> None: + """ + Constituent box groups must be fully disjoint. + Ensure the integrity of the built index. + """ + for box_group in self._bgs: + for box in box_group.boxes: + match_ids = self._find_overlap_box_groups(query=box) + if len(match_ids) > 1: + matches = [self._bgs[match_id].to_json() for match_id in match_ids] + raise ValueError( + f"Detected overlap! While processing the Box {box} as part of query BoxGroup {box_group.to_json()}, we found that it overlaps with existing BoxGroup(s):\n" + + "\n".join( + [f"\t{i}\t{m} " for i, m in zip(match_ids, matches)] + ) + ) + + def find(self, query: BoxGroup) -> List[BoxGroup]: + if not isinstance(query, BoxGroup): + raise ValueError( + f"BoxGroupIndexer only works with `query` that is BoxGroup type" + ) + + if not query.boxes: + return [] + + match_ids = [] + for box in query.boxes: + match_ids.extend(self._find_overlap_box_groups(query=box)) + + return [self._bgs[match_id] for match_id in sorted(set(match_ids))] diff --git a/src/mmda/types/span.py b/src/mmda/types/span.py index 2ce68feb..53cd931a 100644 --- a/src/mmda/types/span.py +++ b/src/mmda/types/span.py @@ -38,9 +38,10 @@ def __lt__(self, other: "Span"): return self.start < other.start @classmethod - def small_spans_to_big_span(cls, spans: List["Span"]) -> "Span": - # TODO: add warning for unsorted spans or not-contiguous spans - # TODO: what happens when Boxes cant be merged? + def small_spans_to_big_span( + cls, spans: List["Span"], merge_boxes: bool = True + ) -> "Span": + # TODO: add warning for non-contiguous spans? start = spans[0].start end = spans[0].end for span in spans[1:]: @@ -48,12 +49,54 @@ def small_spans_to_big_span(cls, spans: List["Span"]) -> "Span": start = span.start if span.end > end: end = span.end + if merge_boxes and all(span.box for span in spans): + new_box = Box.small_boxes_to_big_box(boxes=[span.box for span in spans]) + else: + new_box = None return Span( start=start, end=end, - box=Box.small_boxes_to_big_box(boxes=[span.box for span in spans]), + box=new_box, ) + @classmethod + def cluster_spans(cls, spans: List["Span"]) -> List[List[int]]: + """ + Cluster spans into groups based on any overlap. 
+ """ + if not spans: + return [] + + clusters: List[List[int]] = [[0]] + cluster_id_to_big_span: Dict[int, Span] = {0: spans[0]} + for span_id in range(1, len(spans)): + span = spans[span_id] + + # check all the clusters to see if the span overlaps with any of them + is_overlap = False + for cluster_id, big_span in cluster_id_to_big_span.items(): + if span.is_overlap(big_span): + is_overlap = True + break + + # resolve + if is_overlap: + clusters[cluster_id].append(span_id) + cluster_id_to_big_span[cluster_id] = cls.small_spans_to_big_span( + [span, big_span], + merge_boxes=False, + ) + else: + clusters.append([span_id]) + cluster_id_to_big_span[len(clusters) - 1] = span + + # sort clusters + for cluster in clusters: + cluster.sort() + clusters.sort(key=lambda x: x[0]) + + return clusters + def is_overlap(self, other: "Span") -> bool: is_self_before_other = self.start < other.end and self.end > other.start is_other_before_self = other.start < self.end and other.end > self.start diff --git a/src/mmda/utils/tools.py b/src/mmda/utils/tools.py index 9effac49..96394113 100644 --- a/src/mmda/utils/tools.py +++ b/src/mmda/utils/tools.py @@ -1,10 +1,10 @@ from __future__ import annotations +import itertools import logging from collections import defaultdict from itertools import groupby -import itertools -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple import numpy as np @@ -14,7 +14,12 @@ def allocate_overlapping_tokens_for_box( - tokens: List[SpanGroup], box, token_box_in_box_group: bool = False, x: float = 0.0, y: float = 0.0, center: bool = False + tokens: List[SpanGroup], + box, + token_box_in_box_group: bool = False, + x: float = 0.0, + y: float = 0.0, + center: bool = False, ) -> Tuple[List[Span], List[Span]]: """Finds overlap of tokens for given box Args @@ -29,10 +34,14 @@ def allocate_overlapping_tokens_for_box( """ allocated_tokens, remaining_tokens = [], [] for token in tokens: - if token_box_in_box_group and token.box_group.boxes[0].is_overlap(other=box, x=x, y=y, center=center): + if token_box_in_box_group and token.box_group.boxes[0].is_overlap( + other=box, x=x, y=y, center=center + ): # The token "box" is stored within the SpanGroup's .box_group allocated_tokens.append(token) - elif token.spans[0].box is not None and token.spans[0].box.is_overlap(other=box, x=x, y=y, center=center): + elif token.spans[0].box is not None and token.spans[0].box.is_overlap( + other=box, x=x, y=y, center=center + ): # default to assuming the token "box" is stored in the SpanGroup .box allocated_tokens.append(token) else: @@ -41,7 +50,7 @@ def allocate_overlapping_tokens_for_box( def box_groups_to_span_groups( - box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False + box_groups: List[BoxGroup], doc: Document, pad_x: bool = False, center: bool = False ) -> List[SpanGroup]: """Generate SpanGroups from BoxGroups. 
Args @@ -60,23 +69,22 @@ def box_groups_to_span_groups( token_box_in_box_group = None for box_id, box_group in enumerate(box_groups): - all_tokens_overlapping_box_group = [] for box in box_group.boxes: - # Caching the page tokens to avoid duplicated search if box.page not in all_page_tokens: - cur_page_tokens = all_page_tokens[box.page] = doc.pages[ - box.page - ].tokens + cur_page_tokens = all_page_tokens[box.page] = doc.pages[box.page].tokens if token_box_in_box_group is None: # Determine whether box is stored on token SpanGroup span.box or in the box_group token_box_in_box_group = all( [ ( - (hasattr(token.box_group, "boxes") and len(token.box_group.boxes) == 1) - and token.spans[0].box is None + ( + hasattr(token.box_group, "boxes") + and len(token.box_group.boxes) == 1 + ) + and token.spans[0].box is None ) for token in cur_page_tokens ] @@ -84,9 +92,16 @@ def box_groups_to_span_groups( # Determine average width of tokens on this page if we are going to pad x if pad_x: if token_box_in_box_group and box.page not in avg_token_widths: - avg_token_widths[box.page] = np.average([t.box_group.boxes[0].w for t in cur_page_tokens]) - elif not token_box_in_box_group and box.page not in avg_token_widths: - avg_token_widths[box.page] = np.average([t.spans[0].box.w for t in cur_page_tokens]) + avg_token_widths[box.page] = np.average( + [t.box_group.boxes[0].w for t in cur_page_tokens] + ) + elif ( + not token_box_in_box_group + and box.page not in avg_token_widths + ): + avg_token_widths[box.page] = np.average( + [t.spans[0].box.w for t in cur_page_tokens] + ) else: cur_page_tokens = all_page_tokens[box.page] @@ -99,7 +114,7 @@ def box_groups_to_span_groups( # optionally pad x a small amount so that extra narrow token boxes (when split at punctuation) are not missed x=avg_token_widths.get(box.page, 0.0) * 0.5 if pad_x else 0.0, y=0.0, - center=center + center=center, ) all_page_tokens[box.page] = remaining_tokens @@ -113,7 +128,8 @@ def box_groups_to_span_groups( else MergeSpans( list_of_spans=list( itertools.chain.from_iterable( - span_group.spans for span_group in all_tokens_overlapping_box_group + span_group.spans + for span_group in all_tokens_overlapping_box_group ) ), index_distance=1, @@ -131,9 +147,11 @@ def box_groups_to_span_groups( ) if not token_box_in_box_group: - logging.warning("tokens with box stored in SpanGroup span.box will be deprecated (that is, " - "future Spans wont contain box). Ensure Document is annotated with tokens " - "having box stored in SpanGroup box_group.boxes") + logging.warning( + "tokens with box stored in SpanGroup span.box will be deprecated (that is, " + "future Spans wont contain box). 
Ensure Document is annotated with tokens " + "having box stored in SpanGroup box_group.boxes" + ) del all_page_tokens @@ -150,6 +168,7 @@ def box_groups_to_span_groups( # ) return derived_span_groups + class MergeSpans: """ Given w=width and h=height merge neighboring spans which are w, h or less apart or by merging neighboring spans @@ -201,20 +220,24 @@ def build_graph_index_overlap(self): """ starts_matrix = np.full( (len(self.list_of_spans), len(self.list_of_spans)), - [span.start for span in self.list_of_spans] + [span.start for span in self.list_of_spans], ) ends_matrix = np.full( (len(self.list_of_spans), len(self.list_of_spans)), - [span.end for span in self.list_of_spans] + [span.end for span in self.list_of_spans], ) starts_minus_ends = np.abs(starts_matrix - ends_matrix.T) ends_minus_starts = np.abs(ends_matrix - starts_matrix.T) - are_neighboring_spans = np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance - neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) + are_neighboring_spans = ( + np.minimum(starts_minus_ends, ends_minus_starts) <= self.index_distance + ) + neighboring_spans = np.transpose(are_neighboring_spans.nonzero()) if len(neighboring_spans) > 0: - neighboring_spans_no_dupes = neighboring_spans[np.where(neighboring_spans[:,1] < neighboring_spans[:,0])] + neighboring_spans_no_dupes = neighboring_spans[ + np.where(neighboring_spans[:, 1] < neighboring_spans[:, 0]) + ] for j, i in neighboring_spans_no_dupes: span_i = self.list_of_spans[i] @@ -305,14 +328,8 @@ def build_merged_spans_from_connected_components(self, index): for span in page_spans: spans_by_page[pg].append(span) for page_spans in spans_by_page.values(): - merged_box = Box.small_boxes_to_big_box( - [span.box for span in page_spans] - ) - merged_spans.append( - Span( - start=min([span.start for span in page_spans]), - end=max([span.end for span in page_spans]), - box=merged_box, - ) + merged_span = Span.small_spans_to_big_span( + spans=page_spans, merge_boxes=True ) + merged_spans.append(merged_span) return merged_spans diff --git a/tests/test_internal_ai2/test_api.py b/tests/test_internal_ai2/test_api.py index 404f2940..37c2c9cb 100644 --- a/tests/test_internal_ai2/test_api.py +++ b/tests/test_internal_ai2/test_api.py @@ -20,104 +20,107 @@ class ClassificationSpanGroup(mmda_api.SpanGroup): class TestApi(unittest.TestCase): def test_vanilla_span_group(self) -> None: - sg_ann = mmda_ann.SpanGroup.from_json({ - 'spans': [{'start': 0, 'end': 1}], - 'id': 1, - 'metadata': {'text': 'hello', 'id': 999} # note id not used; it's just in metadata - }) + sg_ann = mmda_ann.SpanGroup.from_json( + { + "spans": [{"start": 0, "end": 1}], + "id": 1, + "metadata": { + "text": "hello", + "id": 999, + }, # note id not used; it's just in metadata + } + ) sg_api = mmda_api.SpanGroup.from_mmda(sg_ann) - self.assertEqual(sg_api.text, 'hello') + self.assertEqual(sg_api.text, "hello") self.assertEqual(sg_api.id, 1) self.assertEqual(sg_api.attributes.dict(), {}) def test_classification_span_group(self) -> None: - sg_ann = mmda_ann.SpanGroup.from_json({ - 'spans': [{'start': 0, 'end': 1}], - 'metadata': {'text': 'hello', 'id': 1} - }) + sg_ann = mmda_ann.SpanGroup.from_json( + {"spans": [{"start": 0, "end": 1}], "metadata": {"text": "hello", "id": 1}} + ) with self.assertRaises(ValidationError): # this should fail because metadata is missing label # and confidence ClassificationSpanGroup.from_mmda(sg_ann) - sg_ann.metadata.label = 'label' + sg_ann.metadata.label = "label" sg_ann.metadata.score = 0.5 
sg_api = ClassificationSpanGroup.from_mmda(sg_ann) - self.assertEqual( - sg_api.attributes.dict(), {'label': 'label', 'score': 0.5} - ) + self.assertEqual(sg_api.attributes.dict(), {"label": "label", "score": 0.5}) # extra field should just get ignored - sg_ann.metadata.extra = 'extra' - self.assertEqual( - sg_api.attributes.dict(), {'label': 'label', 'score': 0.5} - ) + sg_ann.metadata.extra = "extra" + self.assertEqual(sg_api.attributes.dict(), {"label": "label", "score": 0.5}) with self.assertRaises(ValidationError): # this should fail bc score is not a float - sg_ann.metadata.score = 'not a float' + sg_ann.metadata.score = "not a float" ClassificationSpanGroup.from_mmda(sg_ann) def test_equivalence(self): - sg_ann = mmda_ann.SpanGroup.from_json({ - 'spans': [{'start': 0, 'end': 1}], - 'metadata': {'label': 'label', 'score': 0.5} - }) + sg_ann = mmda_ann.SpanGroup.from_json( + { + "spans": [{"start": 0, "end": 1}], + "metadata": {"label": "label", "score": 0.5}, + } + ) sg_ann_2 = ClassificationSpanGroup.from_mmda(sg_ann).to_mmda() self.assertDictEqual(sg_ann.to_json(), sg_ann_2.to_json()) self.assertDictEqual(sg_ann.__dict__, sg_ann_2.__dict__) - def test_box(self): box = mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0) - assert box.to_mmda() == mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0) + assert ( + box.to_mmda().coordinates + == mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0).coordinates + ) assert mmda_api.Box.from_mmda(box.to_mmda()) == box def test_span(self): - span = mmda_api.Span(start=0, end=1, box=mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0)) - assert span.to_mmda() == mmdaSpan(start=0, end=1, box=mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0)) + span = mmda_api.Span( + start=0, + end=1, + box=mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0), + ) + assert ( + span.to_mmda().to_json() + == mmdaSpan( + start=0, end=1, box=mmdaBox(l=0.1, t=0.1, w=0.1, h=0.1, page=0) + ).to_json() + ) def test_box_group(self): box_group = mmda_api.BoxGroup( - boxes=[ - mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0) - ], + boxes=[mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0)], id=0, - type='test', + type="test", # these attributes are going to be discarded because # BoxGroup is using the default Attributes class - attributes={'one': 'Test string'} + attributes={"one": "Test string"}, ) - self.assertEqual( - mmda_api.BoxGroup.from_mmda(box_group.to_mmda()), - box_group - ) + self.assertEqual(mmda_api.BoxGroup.from_mmda(box_group.to_mmda()), box_group) def test_span_group(self): box_group = mmda_api.BoxGroup( - boxes=[ - mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0) - ], + boxes=[mmda_api.Box(left=0.1, top=0.1, width=0.1, height=0.1, page=0)], id=0, - type='test', - attributes={'one': 'Test string'} + type="test", + attributes={"one": "Test string"}, ) span_group = mmda_api.SpanGroup( spans=[], box_group=box_group, - attributes={'one': 'Test string'}, + attributes={"one": "Test string"}, id=0, - type='test', - text='this is a test' + type="test", + text="this is a test", ) - self.assertEqual( - mmda_api.SpanGroup.from_mmda(span_group.to_mmda()), - span_group - ) + self.assertEqual(mmda_api.SpanGroup.from_mmda(span_group.to_mmda()), span_group) diff --git a/tests/test_parsers/test_pdf_plumber_parser.py b/tests/test_parsers/test_pdf_plumber_parser.py index 6ea06117..ce9a7280 100644 --- a/tests/test_parsers/test_pdf_plumber_parser.py +++ b/tests/test_parsers/test_pdf_plumber_parser.py @@ -19,7 +19,7 @@ class 
TestPDFPlumberParser(unittest.TestCase): def setUp(cls) -> None: cls.fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" - ''' + """ def test_parse(self): parser = PDFPlumberParser() doc = parser.parse(input_pdf_path=self.fixture_path / "1903.10676.pdf") @@ -207,7 +207,7 @@ def test_convert_nested_text_to_doc_json(self): "abc\nd ef", "gh i\njkl", ] - ''' + """ def test_parser_stability(self): """ @@ -223,15 +223,25 @@ def test_parser_stability(self): parser = PDFPlumberParser() - current_doc = parser.parse(input_pdf_path=self.fixture_path / "4be952924cd565488b4a239dc6549095029ee578.pdf") + current_doc = parser.parse( + input_pdf_path=self.fixture_path + / "4be952924cd565488b4a239dc6549095029ee578.pdf" + ) - with open(self.fixture_path / "4be952924cd565488b4a239dc6549095029ee578__pdfplumber_doc.json", "r") as f: + with open( + self.fixture_path + / "4be952924cd565488b4a239dc6549095029ee578__pdfplumber_doc.json", + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) fixture_doc = Document.from_json(fixture_doc_json) - - self.assertEqual(current_doc.symbols, fixture_doc.symbols, msg="Current parse has extracted different text from pdf.") + self.assertEqual( + current_doc.symbols, + fixture_doc.symbols, + msg="Current parse has extracted different text from pdf.", + ) def compare_span_groups(current_doc_sgs, fixture_doc_sgs, annotation_name): current_doc_sgs_simplified = [ @@ -244,17 +254,23 @@ def compare_span_groups(current_doc_sgs, fixture_doc_sgs, annotation_name): self.assertEqual( current_doc_sgs_simplified, fixture_doc_sgs_simplified, - msg=f"Current parse produces different SpanGroups for `{annotation_name}`" + msg=f"Current parse produces different SpanGroups for `{annotation_name}`", ) - current_doc_sg_boxes = [[list(s.box.xywh) + [s.box.page] for s in sg] for sg in current_doc_sgs] - fixture_doc_sg_boxes = [[list(s.box.xywh) + [s.box.page] for s in sg] for sg in current_doc_sgs] + current_doc_sg_boxes = [ + [list(s.box.coordinates) + [s.box.page] for s in sg] + for sg in current_doc_sgs + ] + fixture_doc_sg_boxes = [ + [list(s.box.coordinates) + [s.box.page] for s in sg] + for sg in current_doc_sgs + ] self.assertAlmostEqual( current_doc_sg_boxes, fixture_doc_sg_boxes, places=3, - msg=f"Boxes generated for `{annotation_name}` have changed." 
+ msg=f"Boxes generated for `{annotation_name}` have changed.", ) compare_span_groups(current_doc.tokens, fixture_doc.tokens, "tokens") diff --git a/tests/test_predictors/test_svm_word_predictor.py b/tests/test_predictors/test_svm_word_predictor.py index 308229d7..f4a179f8 100644 --- a/tests/test_predictors/test_svm_word_predictor.py +++ b/tests/test_predictors/test_svm_word_predictor.py @@ -268,7 +268,7 @@ def test_cluster_tokens_by_whitespace(self): else: tokens = [self.doc.tokens[token_id] for token_id in token_ids] spans = [span for token in tokens for span in token.spans] - big_span = Span.small_spans_to_big_span(spans=spans) + big_span = Span.small_spans_to_big_span(spans=spans, merge_boxes=False) self.assertEqual( self.doc.symbols[big_span.start : big_span.end], "".join([token.text for token in tokens]), diff --git a/tests/test_recipes/core_recipe_fixtures.py b/tests/test_recipes/core_recipe_fixtures.py index 0f676847..fd0015fe 100644 --- a/tests/test_recipes/core_recipe_fixtures.py +++ b/tests/test_recipes/core_recipe_fixtures.py @@ -439,13 +439,6 @@ { "start": 3370, "end": 3372, - "box": { - "left": 0.7474297722689077, - "top": 0.7828144700712589, - "width": 0.017234544537815144, - "height": 0.012956175771971501, - "page": 0, - }, } ], "id": 895, @@ -456,13 +449,6 @@ { "start": 3373, "end": 3382, - "box": { - "left": 0.7730432389915969, - "top": 0.7828144700712589, - "width": 0.0749152648739495, - "height": 0.012956175771971501, - "page": 0, - }, } ], "id": 896, @@ -473,13 +459,6 @@ { "start": 3383, "end": 3394, - "box": { - "left": 0.5165052100840336, - "top": 0.7828144700712589, - "width": 0.366509090756303, - "height": 0.029060926365795714, - "page": 0, - }, } ], "id": 897, @@ -490,13 +469,6 @@ { "start": 3395, "end": 3405, - "box": { - "left": 0.5679338243697479, - "top": 0.798919220665083, - "width": 0.07995728588235285, - "height": 0.012956175771971612, - "page": 0, - }, } ], "id": 898, @@ -507,13 +479,6 @@ { "start": 3406, "end": 3408, - "box": { - "left": 0.6542532240336135, - "top": 0.798919220665083, - "width": 0.018242948739495835, - "height": 0.012956175771971612, - "page": 0, - }, } ], "id": 899, diff --git a/tests/test_types/test_annotation.py b/tests/test_types/test_annotation.py index 83e345fb..331249ef 100644 --- a/tests/test_types/test_annotation.py +++ b/tests/test_types/test_annotation.py @@ -1,36 +1,37 @@ +import unittest + from mmda.types.annotation import BoxGroup from mmda.types.box import Box -import unittest class TestBoxGroup(unittest.TestCase): def setUp(cls) -> None: - cls.box_group_json = {'boxes': [{'left': 0.1, - 'top': 0.6, - 'width': 0.36, - 'height': 0.221, - 'page': 0}], - 'id': None, - 'type': 'Text'} + cls.box_group_json = { + "boxes": [ + {"left": 0.1, "top": 0.6, "width": 0.36, "height": 0.221, "page": 0} + ], + "id": None, + "type": "Text", + } def test_from_json(self): self.assertIsInstance(BoxGroup.from_json(self.box_group_json), BoxGroup) - self.assertEqual(BoxGroup.from_json(self.box_group_json).boxes, - [Box(l=0.1, t=0.6, w=0.36, h=0.221, page=0)]) + self.assertEqual( + BoxGroup.from_json(self.box_group_json).boxes[0].to_json(), + Box(l=0.1, t=0.6, w=0.36, h=0.221, page=0).to_json(), + ) self.assertEqual(BoxGroup.from_json(self.box_group_json).id, None) - self.assertEqual(BoxGroup.from_json(self.box_group_json).type, 'Text') + self.assertEqual(BoxGroup.from_json(self.box_group_json).type, "Text") def test_to_json(self): boxgroup = BoxGroup.from_json(self.box_group_json) self.assertIsInstance(boxgroup.to_json(), dict) - 
self.assertEqual(boxgroup.to_json()['boxes'], - [{'left': 0.1, - 'top': 0.6, - 'width': 0.36, - 'height': 0.221, - 'page': 0}]) - - assert 'boxes' in boxgroup.to_json() - assert 'metadata' in boxgroup.to_json() + self.assertEqual( + boxgroup.to_json()["boxes"], + [{"left": 0.1, "top": 0.6, "width": 0.36, "height": 0.221, "page": 0}], + ) + + assert "boxes" in boxgroup.to_json() + assert "metadata" in boxgroup.to_json() diff --git a/tests/test_types/test_box.py b/tests/test_types/test_box.py index 8528ecc6..ea3194f8 100644 --- a/tests/test_types/test_box.py +++ b/tests/test_types/test_box.py @@ -1,18 +1,113 @@ import unittest + from mmda.types import box as mmda_box class TestBox(unittest.TestCase): def setUp(cls) -> None: - cls.box_dict = {'left': 0.2, - 'top': 0.09, - 'width': 0.095, - 'height': 0.017, - 'page': 0} + cls.box_dict = { + "left": 0.2, + "top": 0.09, + "width": 0.095, + "height": 0.017, + "page": 0, + } cls.box = mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0) - def test_from_json(self): - self.assertEqual(self.box.from_json(self.box_dict), self.box) + def test_to_from_json(self): + box = self.box.from_json(self.box_dict) + self.assertDictEqual(box.to_json(), self.box_dict) + + def test_cluster_boxes(self): + # overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0, 1, 2]]) + + # on-overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.3, t=0.20, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.4, t=0.30, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0], [1], [2]]) + + # partially overlapping boxes + boxes = [ + mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.3, t=0.20, w=0.095, h=0.017, page=0), + mmda_box.Box(l=0.301, t=0.201, w=0.095, h=0.017, page=0), + ] + self.assertListEqual(mmda_box.Box.cluster_boxes(boxes), [[0], [1, 2]]) + + def test_create_invalid_box(self): + # relative coordinate box + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=0) + # absolute coordinate box (larger than 1.0) + box = mmda_box.Box(l=0.7 + 0.0000001, t=0.2, w=0.3, h=0.4, page=0) + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6 + 0.000001, page=0) + # negative page + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=0.6, page=-1) + # negative coordinates + with self.assertRaises(ValueError): + box = mmda_box.Box(l=-0.1, t=0.4, w=0.3, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=-0.1, w=0.3, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=0.4, w=-0.1, h=0.6, page=0) + with self.assertRaises(ValueError): + box = mmda_box.Box(l=0.7, t=0.4, w=0.3, h=-0.1, page=0) + + def test_shrink(self): + # usual + box = mmda_box.Box(l=0.1, t=0.2, w=0.3, h=0.4, page=0) + box.shrink(delta=0.1) + self.assertAlmostEqual(box.l, 0.2) # 0.1 + 0.1 + self.assertAlmostEqual(box.t, 0.3) # 0.2 + 0.1 + self.assertAlmostEqual(box.w, 0.1) # 0.3 - 0.1 * 2 + self.assertAlmostEqual(box.h, 0.2) # 0.4 - 0.1 * 2 + + # shrinking until inverts box. would ignore shrinking along appropriate axis. 
+ box = mmda_box.Box(l=0.9, t=0.5, w=0.1, h=0.3, page=0) + box.shrink(delta=0.1, ignore=True) + self.assertAlmostEqual(box.l, 0.9) # ignored + self.assertAlmostEqual(box.t, 0.6) # adjusted; 0.5 + 0.1 + self.assertAlmostEqual(box.w, 0.1) # ignored + self.assertAlmostEqual(box.h, 0.1) # adjusted; 0.3 - 2 * 0.1 + + # shrinking until out of bounds. would clip along appropriate axis. + # actually... does this ever happen unless Box is already out of bounds? - def test_to_json(self): - self.assertEqual(self.box.to_json(), self.box_dict) + def test_cluster_boxes_hard(self): + # from 4be952924cd565488b4a239dc6549095029ee578.pdf, page 2, tokens 650:655 + # boxes = [ + # mmda_box.Box( + # l=0.7761069934640523, + # t=0.14276190217171716, + # w=0.005533858823529373, + # h=0.008037272727272593, + # page=2, + # ), + # mmda_box.Box( + # l=0.7836408522875816, + # t=0.14691867138383832, + # w=0.005239432156862763, + # h=0.005360666666666692, + # page=2, + # ), + # mmda_box.Box( + # l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + # ), + # mmda_box.Box( + # l=1.001, t=0.3424244465151515, w=-0.002, h=0.008037272727272737, page=2 + # ), + # mmda_box.Box( + # l=1.0, t=0.32670311318181816, w=0.0, h=0.010037272727272737, page=2 + # ), + # ] + # TODO: unfinished test + pass diff --git a/tests/test_types/test_document.py b/tests/test_types/test_document.py index 355e4993..13256f94 100644 --- a/tests/test_types/test_document.py +++ b/tests/test_types/test_document.py @@ -1,13 +1,12 @@ import json -import unittest import os +import unittest +from ai2_internal import api from mmda.types.annotation import SpanGroup from mmda.types.document import Document from mmda.types.names import MetadataField, SymbolsField -from ai2_internal import api - def resolve(file: str) -> str: return os.path.join(os.path.dirname(__file__), "../fixtures/types", file) @@ -58,10 +57,15 @@ def test_annotate_box_groups_gets_text(self): spp_doc = Document.from_json(json.load(f)) with open(resolve("test_document_box_groups.json")) as f: - box_groups = [api.BoxGroup(**bg).to_mmda() for bg in json.load(f)["grobid_bibs_box_groups"]] + box_groups = [ + api.BoxGroup(**bg).to_mmda() + for bg in json.load(f)["grobid_bibs_box_groups"] + ] spp_doc.annotate(new_span_groups=box_groups) - assert spp_doc.new_span_groups[0].text.startswith("Gutman G, Rosenzweig D, Golan J") + assert spp_doc.new_span_groups[0].text.startswith( + "Gutman G, Rosenzweig D, Golan J" + ) # when token boxes are on spans plumber_doc = "c8b53e2d9cd247e2d42719e337bfb13784d22bd2.json" @@ -70,7 +74,10 @@ def test_annotate_box_groups_gets_text(self): doc = Document.from_json(json.load(f)) with open(resolve("test_document_box_groups.json")) as f: - box_groups = [api.BoxGroup(**bg).to_mmda() for bg in json.load(f)["grobid_bibs_box_groups"]] + box_groups = [ + api.BoxGroup(**bg).to_mmda() + for bg in json.load(f)["grobid_bibs_box_groups"] + ] doc.annotate(new_span_groups=box_groups) assert doc.new_span_groups[0].text.startswith("Gutman G, Rosenzweig D, Golan J") @@ -80,7 +87,10 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): a last-known-good fixture """ # basic doc annotated with pages and tokens, from pdfplumber parser split at punctuation - with open(resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json"), "r") as f: + with open( + resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json"), + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) doc = Document.from_json(fixture_doc_json) 
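For orientation, a minimal sketch of how the box-to-span conversion exercised by these tests might be driven end to end; the PDF path is a placeholder and the region coordinates are arbitrary relative values:

    from mmda.parsers import PDFPlumberParser
    from mmda.types import Box, BoxGroup
    from mmda.utils.tools import box_groups_to_span_groups

    # parse a PDF into a Document with pages and tokens ("paper.pdf" is a placeholder path)
    doc = PDFPlumberParser().parse(input_pdf_path="paper.pdf")

    # a box drawn around a region of page 0, in relative coordinates (arbitrary values)
    region = BoxGroup(boxes=[Box(l=0.1, t=0.1, w=0.8, h=0.2, page=0)])

    # derive SpanGroups covering the tokens whose boxes fall inside the region
    span_groups = box_groups_to_span_groups(box_groups=[region], doc=doc, center=True)
    print(span_groups[0].spans)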
@@ -88,9 +98,16 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): # spangroups derived from boxgroups of boxes drawn neatly around bib entries by calling `.annotate` on # list of BoxGroups fixture_span_groups = [] - with open(resolve("20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entry_span_groups_from_box_groups.json"), "r") as f: + with open( + resolve( + "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entry_span_groups_from_box_groups.json" + ), + "r", + ) as f: raw_json = f.read() - fixture_bib_entries_json = json.loads(raw_json)["bib_entry_span_groups_from_box_groups"] + fixture_bib_entries_json = json.loads(raw_json)[ + "bib_entry_span_groups_from_box_groups" + ] # make box_groups to annotate from test fixture bib entry span groups, and save the for bib_entry in fixture_bib_entries_json: @@ -105,6 +122,8 @@ def test_annotate_box_groups_allocates_all_overlapping_tokens(self): doc.annotate(fixture_box_groups=fixture_box_groups) for sg1, sg2 in zip(fixture_span_groups, doc.fixture_box_groups): - assert sg1.spans == sg2.spans assert sg1.text == sg2.text - + for sg1_span, sg2_span in zip(sg1.spans, sg2.spans): + assert sg1_span.start == sg2_span.start + assert sg1_span.end == sg2_span.end + assert sg1_span.box.coordinates == sg2_span.box.coordinates diff --git a/tests/test_types/test_indexers.py b/tests/test_types/test_indexers.py index 05ab6885..2cb7b33c 100644 --- a/tests/test_types/test_indexers.py +++ b/tests/test_types/test_indexers.py @@ -1,19 +1,13 @@ import unittest -from mmda.types import SpanGroup, Span -from mmda.types.indexers import SpanGroupIndexer +from mmda.types import Box, BoxGroup, Span, SpanGroup +from mmda.types.indexers import BoxGroupIndexer, SpanGroupIndexer class TestSpanGroupIndexer(unittest.TestCase): def test_overlap_within_single_spangroup_fails_checks(self): span_groups = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(4, 7) - ] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(4, 7)], allow_overlap=True) ] with self.assertRaises(ValueError): @@ -21,17 +15,8 @@ def test_overlap_within_single_spangroup_fails_checks(self): def test_overlap_between_spangroups_fails_checks(self): span_groups = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(5, 8) - ] - ), - SpanGroup( - id=2, - spans=[Span(6, 10)] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(5, 8)]), + SpanGroup(id=2, spans=[Span(6, 10)]), ] with self.assertRaises(ValueError): @@ -39,21 +24,9 @@ def test_overlap_between_spangroups_fails_checks(self): def test_finds_matching_groups_in_doc_order(self): span_groups_to_index = [ - SpanGroup( - id=1, - spans=[ - Span(0, 5), - Span(5, 8) - ] - ), - SpanGroup( - id=2, - spans=[Span(9, 10)] - ), - SpanGroup( - id=3, - spans=[Span(100, 105)] - ) + SpanGroup(id=1, spans=[Span(0, 5), Span(5, 8)]), + SpanGroup(id=2, spans=[Span(9, 10)]), + SpanGroup(id=3, spans=[Span(100, 105)]), ] index = SpanGroupIndexer(span_groups_to_index) @@ -66,4 +39,67 @@ def test_finds_matching_groups_in_doc_order(self): self.assertEqual(matches, [span_groups_to_index[0], span_groups_to_index[1]]) +class TestBoxGroupIndexer(unittest.TestCase): + def test_overlap_within_single_boxgroup_fails_checks(self): + box_groups = [ + BoxGroup( + id=1, + boxes=[Box(0, 0, 5, 5, page=0), Box(4, 4, 7, 7, page=0)], + allow_overlap=True, + ) + ] + + with self.assertRaises(ValueError): + BoxGroupIndexer(box_groups) + + def test_overlap_between_boxgroups_fails_checks(self): + box_groups = [ + BoxGroup( + id=1, boxes=[Box(0, 0, 5, 5, page=0), Box(5.01, 5.01, 8, 8, page=0)] + ), + 
BoxGroup(id=2, boxes=[Box(6, 6, 10, 10, page=0)]), + ] + + with self.assertRaises(ValueError): + BoxGroupIndexer(box_groups) + def test_finds_matching_groups_in_doc_order(self): + box_groups_to_index = [ + BoxGroup(id=1, boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=0)]), + BoxGroup(id=2, boxes=[Box(4, 4, 1, 1, page=0)]), + BoxGroup(id=3, boxes=[Box(100, 100, 1, 1, page=0)]), + ] + + index = BoxGroupIndexer(box_groups_to_index) + + # should intersect 1 and 2 but not 3 + probe = BoxGroup(id=4, boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)]) + matches = index.find(probe) + + self.assertEqual(len(matches), 2) + self.assertEqual(matches, [box_groups_to_index[0], box_groups_to_index[1]]) + + def test_finds_matching_groups_accounts_for_pages(self): + box_groups_to_index = [ + BoxGroup(id=1, boxes=[Box(0, 0, 1, 1, page=0), Box(2, 2, 1, 1, page=1)]), + BoxGroup(id=2, boxes=[Box(4, 4, 1, 1, page=1)]), + BoxGroup(id=3, boxes=[Box(100, 100, 1, 1, page=0)]), + ] + + index = BoxGroupIndexer(box_groups_to_index) + + # shouldnt intersect any given page 0 + probe = BoxGroup(id=4, boxes=[Box(1, 1, 5, 5, page=0), Box(9, 9, 5, 5, page=0)]) + matches = index.find(probe) + + self.assertEqual(len(matches), 1) + self.assertEqual(matches, [box_groups_to_index[0]]) + + # shoudl intersect after switching to page 1 (and the page 2 box doesnt intersect) + probe = BoxGroup( + id=4, boxes=[Box(1, 1, 5, 5, page=1), Box(100, 100, 1, 1, page=2)] + ) + matches = index.find(probe) + + self.assertEqual(len(matches), 2) + self.assertEqual(matches, [box_groups_to_index[0], box_groups_to_index[1]]) diff --git a/tests/test_types/test_json_conversion.py b/tests/test_types/test_json_conversion.py index e7a5f27d..765aba4d 100644 --- a/tests/test_types/test_json_conversion.py +++ b/tests/test_types/test_json_conversion.py @@ -1,16 +1,15 @@ -''' +""" Description: Test whether all properties for an mmda doc are preserved when converting to json and back. Author: @soldni -''' +""" import json from pathlib import Path -from mmda.types import BoxGroup, SpanGroup, Document, Metadata from mmda.parsers import PDFPlumberParser - +from mmda.types import BoxGroup, Document, Metadata, SpanGroup PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf" @@ -45,10 +44,7 @@ def test_doc_conversion(): for field_name in orig_doc.fields: # this iterates over all span group for this field in both docs - field_it = zip( - getattr(orig_doc, field_name), - getattr(new_doc, field_name) - ) + field_it = zip(getattr(orig_doc, field_name), getattr(new_doc, field_name)) # type annotations to keep mypy quiet orig_sg: SpanGroup @@ -58,4 +54,7 @@ def test_doc_conversion(): # for each pair, they should have same metadata (type, id, # and optionally, text) and same spans. 
assert orig_sg.metadata == new_sg.metadata - assert orig_sg.spans == new_sg.spans + for orig_span, new_span in zip(orig_sg.spans, new_sg.spans): + assert orig_span.start == new_span.start + assert orig_span.end == new_span.end + assert orig_span.box.coordinates == new_span.box.coordinates diff --git a/tests/test_types/test_span.py b/tests/test_types/test_span.py index 53466097..1518a114 100644 --- a/tests/test_types/test_span.py +++ b/tests/test_types/test_span.py @@ -5,9 +5,9 @@ class TestSpan(unittest.TestCase): - def setUp(cls): - cls.span = mmda_span.Span(start=0, end=0) - cls.span_dict = { + def test_to_from_json(self): + span = mmda_span.Span(start=0, end=0) + span_dict = { "start": 0, "end": 8, "box": { @@ -18,21 +18,90 @@ def setUp(cls): "page": 0, }, } - - def test_from_json(self): self.assertEqual( - self.span.from_json(self.span_dict), + span.from_json(span_dict).to_json(), mmda_span.Span( start=0, end=8, box=mmda_box.Box(l=0.2, t=0.09, w=0.095, h=0.017, page=0), - ), + ).to_json(), ) - def test_to_json(self): - self.assertEqual(self.span.from_json(self.span_dict).to_json(), self.span_dict) - def test_is_overlap(self): + span1 = mmda_span.Span(start=0, end=8) + span2 = mmda_span.Span(start=0, end=8) + self.assertTrue(span1.is_overlap(span2)) + + span3 = mmda_span.Span(start=2, end=5) + self.assertTrue(span1.is_overlap(span3)) + + span4 = mmda_span.Span(start=8, end=10) + self.assertFalse(span1.is_overlap(span4)) + + span5 = mmda_span.Span(start=10, end=12) + self.assertFalse(span1.is_overlap(span5)) + + def test_small_spans_to_big_span(self): + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=16, end=24), + ] + self.assertEqual( + mmda_span.Span.small_spans_to_big_span(spans=spans, merge_boxes=False), + mmda_span.Span(start=0, end=24), + ) + # if no boxes, should still work + self.assertEqual( + mmda_span.Span.small_spans_to_big_span(spans=spans, merge_boxes=True), + mmda_span.Span(start=0, end=24), + ) + + def test_small_spans_to_big_span_unsorted(self): + spans = [ + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=16, end=24), + ] + self.assertEqual( + mmda_span.Span.small_spans_to_big_span(spans=spans), + mmda_span.Span(start=0, end=24), + ) + + spans = [ + mmda_span.Span(start=16, end=24), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=0, end=8), + ] + self.assertEqual( + mmda_span.Span.small_spans_to_big_span(spans=spans), + mmda_span.Span(start=0, end=24), + ) + + def test_cluster_spans(self): + # overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=0, end=8), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans=spans), [[0, 1, 2]]) + + # non-overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=8, end=16), + mmda_span.Span(start=16, end=24), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0], [1], [2]]) + + # partially overlapping spans + spans = [ + mmda_span.Span(start=0, end=8), + mmda_span.Span(start=9, end=16), + mmda_span.Span(start=10, end=15), + ] + self.assertListEqual(mmda_span.Span.cluster_spans(spans), [[0], [1, 2]]) span = mmda_span.Span(start=0, end=2) self.assertTrue(span.is_overlap(mmda_span.Span(start=0, end=1))) self.assertTrue(span.is_overlap(mmda_span.Span(start=1, end=2))) diff --git a/tests/test_utils/test_tools.py b/tests/test_utils/test_tools.py index 6b39bdb1..a35645a7 100644 --- a/tests/test_utils/test_tools.py +++ 
b/tests/test_utils/test_tools.py @@ -10,18 +10,15 @@ import unittest from mmda.types.annotation import BoxGroup, SpanGroup -from mmda.types.span import Span from mmda.types.box import Box from mmda.types.document import Document - -from mmda.utils.tools import MergeSpans -from mmda.utils.tools import box_groups_to_span_groups +from mmda.types.span import Span +from mmda.utils.tools import MergeSpans, box_groups_to_span_groups fixture_path = pathlib.Path(__file__).parent.parent / "fixtures" / "utils" class TestMergeNeighborSpans(unittest.TestCase): - def test_merge_multiple_neighbor_spans(self): spans = [Span(start=0, end=10), Span(start=11, end=20), Span(start=21, end=30)] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) @@ -54,7 +51,9 @@ def test_different_index_distances(self): def test_zero_index_distance(self): spans = [Span(start=0, end=10), Span(start=10, end=20)] - out = MergeSpans(list_of_spans=spans, index_distance=0).merge_neighbor_spans_by_symbol_distance() + out = MergeSpans( + list_of_spans=spans, index_distance=0 + ).merge_neighbor_spans_by_symbol_distance() assert len(out) == 1 assert isinstance(out[0], Span) assert out[0].start == 0 @@ -64,7 +63,7 @@ def test_handling_of_boxes(self): spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=0)), - Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) + Span(start=21, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -77,14 +76,14 @@ def test_handling_of_boxes(self): assert out[0].end == 20 assert out[1].start == 21 assert out[1].end == 150 - assert out[0].box == Box(l=0, t=0, w=3, h=3, page=0) + assert out[0].box.coordinates == Box(l=0, t=0, w=3, h=3, page=0).coordinates # unmerged spans from separate pages keep their original box - assert out[1].box == spans[-1].box + assert out[1].box.coordinates == spans[-1].box.coordinates spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=1)), Span(start=11, end=20, box=Box(l=1, t=1, w=2, h=2, page=1)), - Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)) + Span(start=100, end=150, box=Box(l=2, t=2, w=3, h=3, page=1)), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) @@ -96,15 +95,15 @@ def test_handling_of_boxes(self): assert out[0].end == 20 assert out[1].start == 100 assert out[1].end == 150 - assert out[0].box == Box(l=0, t=0, w=3, h=3, page=1) + assert out[0].box.coordinates == Box(l=0, t=0, w=3, h=3, page=1).coordinates # unmerged spans that were too far apart in symbol distance keep their original box - assert out[1].box == spans[-1].box + assert out[1].box.coordinates == spans[-1].box.coordinates spans = [ Span(start=0, end=10, box=Box(l=0, t=0, w=1, h=1, page=0)), Span(start=11, end=20), Span(start=21, end=150), - Span(start=155, end=200) + Span(start=155, end=200), ] merge_spans = MergeSpans(list_of_spans=spans, index_distance=1) merge_spans.merge_neighbor_spans_by_symbol_distance() @@ -126,223 +125,721 @@ def test_handling_of_boxes(self): list_of_spans_to_merge = [ - Span(start=3944, end=3948, - box=Box(l=0.19238134915568578, t=0.22752901673615306, w=0.06941334053447479, h=0.029442207414270286, - page=4)), - Span(start=3949, end=3951, - box=Box(l=0.27220460878651254, t=0.22752901673615306, w=0.03468585042904468, h=0.029442207414270286, - page=4)), - Span(start=4060, end=4063, - box=Box(l=0.4204075769894973, 
t=0.34144142726484455, w=0.023417310961637895, h=0.014200429984914883, - page=4)), - Span(start=4072, end=4075, - box=Box(l=0.5182742633669088, t=0.34144142726484455, w=0.029000512031393755, h=0.014200429984914883, - page=4)), - Span(start=4076, end=4083, - box=Box(l=0.5522956396696659, t=0.34144142726484455, w=0.06440764687304719, h=0.014200429984914883, - page=4)), - Span(start=4119, end=4128, - box=Box(l=0.2686971421659869, t=0.36273518298114954, w=0.08479235581478171, h=0.014200429984914883, - page=4)), - Span(start=4134, end=4144, - box=Box(l=0.40387889180816966, t=0.36273518298114954, w=0.08368776567508182, h=0.014200429984914883, - page=4)), - Span(start=4145, end=4148, - box=Box(l=0.4943548659781345, t=0.36273518298114954, w=0.042396177907390975, h=0.014200429984914883, - page=4)), - Span(start=4149, end=4162, - box=Box(l=0.5435392523804085, t=0.36273518298114954, w=0.11491754144296094, h=0.014200429984914883, - page=4)), - Span(start=4166, end=4177, - box=Box(l=0.6876581404256177, t=0.36273518298114954, w=0.09146006356715199, h=0.014200429984914883, - page=4)), - Span(start=4419, end=4427, - box=Box(l=0.2686971421659869, t=0.4479113936500019, w=0.06846450520430858, h=0.014200429984914883, - page=4)), - Span(start=4497, end=4505, - box=Box(l=0.2686971421659869, t=0.46920514936630686, w=0.06846450520430858, h=0.014200429984914883, - page=4)), - Span(start=4517, end=4520, - box=Box(l=0.42195400318507725, t=0.46920514936630686, w=0.029000512031393755, h=0.014200429984914883, - page=4)), - Span(start=4574, end=4581, - box=Box(l=0.2686971421659869, t=0.49049890508261185, w=0.07810456460532592, h=0.014200429984914883, - page=4)), - Span(start=4582, end=4587, - box=Box(l=0.35061756361754887, t=0.49049890508261185, w=0.03904224057412029, h=0.014200429984914883, - page=4)), - Span(start=4588, end=4591, - box=Box(l=0.39347566103790516, t=0.49049890508261185, w=0.023417310961637943, h=0.014200429984914883, - page=4)), - Span(start=4592, end=4601, - box=Box(l=0.4207088288457791, t=0.49049890508261185, w=0.08254300862121101, h=0.014200429984914883, - page=4)), - Span(start=4602, end=4613, - box=Box(l=0.5070676943132262, t=0.49049890508261185, w=0.09481400090042272, h=0.014200429984914883, - page=4)),] - -list_of_spans_to_merge_2 = [Span(start=30113, end=30119, - box=Box(l=0.12095229775767885, t=0.3578497466414853, w=0.05243790645011725, - h=0.014200429984914883, page=19)), - Span(start=30120, end=30124, - box=Box(l=0.17929474059091924, t=0.3578497466414853, w=0.030687522426571887, - h=0.014200429984914883, - page=19)), - Span(start=30125, end=30129, - box=Box(l=0.21799556239458678, t=0.3578497466414853, w=0.04350076804709073, - h=0.014200429984914883, page=19)), - Span(start=30130, end=30135, - box=Box(l=0.26740086682480063, t=0.3578497466414853, w=0.050208642713631964, - h=0.014200429984914883, - page=19)), - Span(start=30136, end=30141, - box=Box(l=0.32351404592155575, t=0.3578497466414853, w=0.0446254416438761, - h=0.014200429984914883, page=19)), - Span(start=30142, end=30151, - box=Box(l=0.37404402394855496, t=0.3578497466414853, w=0.0769598075514552, - h=0.014200429984914883, page=19)), - Span(start=30152, end=30155, - box=Box(l=0.4569284513402187, t=0.3578497466414853, w=0.029000512031393852, - h=0.014200429984914883, page=19)), - Span(start=30156, end=30165, - box=Box(l=0.4918334997547357, t=0.3578497466414853, w=0.0792091547450259, - h=0.014200429984914883, page=19)), - Span(start=30166, end=30175, - box=Box(l=0.5769471908828846, t=0.3578497466414853, w=0.07175819216632291, - 
h=0.014200429984914883, page=19)), - Span(start=30176, end=30179, - box=Box(l=0.6576023545380633, t=0.3578497466414853, w=0.03122977576787907, - h=0.014200429984914883, page=19)), - Span(start=30180, end=30184, - box=Box(l=0.6947366666890655, t=0.3578497466414853, w=0.03904224057412024, - h=0.014200429984914883, page=19)), - Span(start=30185, end=30190, - box=Box(l=0.7396834436463088, t=0.3578497466414853, w=0.05020864271363187, - h=0.014200429984914883, page=19)), - Span(start=30191, end=30193, - box=Box(l=0.7957966227430638, t=0.3578497466414853, w=0.015624929612482252, - h=0.014200429984914883, page=19)), - Span(start=30194, end=30197, - box=Box(l=0.12095229775767885, t=0.37500875791374183, w=0.024541984558423317, - h=0.014200429984914883, - page=19)), - Span(start=30198, end=30207, - box=Box(l=0.1518205712980198, t=0.37500875791374183, w=0.07695980755145514, - h=0.014200429984914883, page=19)), - Span(start=30208, end=30210, - box=Box(l=0.2351066678313926, t=0.37500875791374183, w=0.013395665875996984, - h=0.014200429984914883, - page=19)), - Span(start=30211, end=30214, - box=Box(l=0.2548286226893072, t=0.37500875791374183, w=0.02231272082193805, - h=0.014200429984914883, page=19)), - Span(start=30215, end=30217, - box=Box(l=0.283467632493163, t=0.37500875791374183, w=0.015624929612482252, - h=0.014200429984914883, page=19)), - Span(start=30218, end=30221, - box=Box(l=0.3054188510875629, t=0.37500875791374183, w=0.024541984558423317, - h=0.014200429984914883, - page=19)), - Span(start=30222, end=30229, - box=Box(l=0.33628712462790383, t=0.37500875791374183, w=0.055570925755447906, - h=0.014200429984914883, - page=19)), - Span(start=30230, end=30235, - box=Box(l=0.3981843393652693, t=0.37500875791374183, w=0.04183384110899822, - h=0.014200429984914883, page=19)), - Span(start=30236, end=30240, - box=Box(l=0.44668588822663785, t=0.37500875791374183, w=0.03570838669793504, - h=0.014200429984914883, - page=19)), - Span(start=30241, end=30244, - box=Box(l=0.4887205639064905, t=0.37500875791374183, w=0.020083457085452783, - h=0.014200429984914883, - page=19)), - Span(start=30245, end=30255, - box=Box(l=0.5151303099738609, t=0.37500875791374183, w=0.08810612623388145, - h=0.014200429984914883, page=19)), - Span(start=30256, end=30259, - box=Box(l=0.6095627251896601, t=0.37500875791374183, w=0.022312720821938, - h=0.014200429984914883, page=19)), - Span(start=30260, end=30262, - box=Box(l=0.6382017349935157, t=0.37500875791374183, w=0.015624929612482252, - h=0.014200429984914883, - page=19)), - Span(start=30263, end=30268, - box=Box(l=0.6601529535879158, t=0.37500875791374183, w=0.03958449391542752, - h=0.014200429984914883, page=19)), - Span(start=30269, end=30273, - box=Box(l=0.7098795933314969, t=0.37500875791374183, w=0.035708386697935225, - h=0.014200429984914883, - page=19)), - Span(start=30274, end=30276, - box=Box(l=0.7519142690113497, t=0.37500875791374183, w=0.013395665875997033, - h=0.014200429984914883, - page=19)), - Span(start=30277, end=30278, - box=Box(l=0.7716362238692644, t=0.37500875791374183, w=0.008917054945941066, - h=0.014200429984914883, - page=19)), - Span(start=30279, end=30281, - box=Box(l=0.7868795677971232, t=0.37500875791374183, w=0.02454198455842322, - h=0.014200429984914883, page=19)), - Span(start=30282, end=30291, - box=Box(l=0.12095229775767885, t=0.3921677691859983, w=0.08031374488472577, - h=0.014200429984914883, page=19)), - Span(start=30292, end=30296, - box=Box(l=0.2062869069137678, t=0.3921677691859983, w=0.03904224057412024, - h=0.014200429984914883, 
page=19)), - Span(start=30297, end=30302, - box=Box(l=0.25035001175925126, t=0.3921677691859983, w=0.050208642713631964, - h=0.014200429984914883, - page=19)), - Span(start=30303, end=30311, - box=Box(l=0.30557951874424644, t=0.3921677691859983, w=0.08143841848151108, - h=0.014200429984914883, page=19)), - Span(start=30312, end=30314, - box=Box(l=0.3920388014971207, t=0.3921677691859983, w=0.016729519752182193, - h=0.014200429984914883, page=19)), - Span(start=30315, end=30321, - box=Box(l=0.4137891855206661, t=0.3921677691859983, w=0.0535625800469026, - h=0.014200429984914883, page=19)), - Span(start=30322, end=30328, - box=Box(l=0.47237262983893197, t=0.3921677691859983, w=0.05354249658981717, - h=0.014200429984914883, page=19)), - Span(start=30329, end=30333, - box=Box(l=0.5309359907001122, t=0.3921677691859983, w=0.03681297683763493, - h=0.014200429984914883, page=19)), - Span(start=30334, end=30336, - box=Box(l=0.5727698318091105, t=0.3921677691859983, w=0.01672951975218224, - h=0.014200429984914883, page=19)), - Span(start=30337, end=30344, - box=Box(l=0.5945202158326559, t=0.3921677691859983, w=0.060230287799273016, - h=0.014200429984914883, page=19)), - Span(start=30345, end=30348, - box=Box(l=0.6597713679032922, t=0.3921677691859983, w=0.029000512031393946, - h=0.014200429984914883, page=19)), - Span(start=30349, end=30359, - box=Box(l=0.6937927442060494, t=0.3921677691859983, w=0.07834556609035141, - h=0.014200429984914883, page=19))] + Span( + start=3944, + end=3948, + box=Box( + l=0.19238134915568578, + t=0.22752901673615306, + w=0.06941334053447479, + h=0.029442207414270286, + page=4, + ), + ), + Span( + start=3949, + end=3951, + box=Box( + l=0.27220460878651254, + t=0.22752901673615306, + w=0.03468585042904468, + h=0.029442207414270286, + page=4, + ), + ), + Span( + start=4060, + end=4063, + box=Box( + l=0.4204075769894973, + t=0.34144142726484455, + w=0.023417310961637895, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4072, + end=4075, + box=Box( + l=0.5182742633669088, + t=0.34144142726484455, + w=0.029000512031393755, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4076, + end=4083, + box=Box( + l=0.5522956396696659, + t=0.34144142726484455, + w=0.06440764687304719, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4119, + end=4128, + box=Box( + l=0.2686971421659869, + t=0.36273518298114954, + w=0.08479235581478171, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4134, + end=4144, + box=Box( + l=0.40387889180816966, + t=0.36273518298114954, + w=0.08368776567508182, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4145, + end=4148, + box=Box( + l=0.4943548659781345, + t=0.36273518298114954, + w=0.042396177907390975, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4149, + end=4162, + box=Box( + l=0.5435392523804085, + t=0.36273518298114954, + w=0.11491754144296094, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4166, + end=4177, + box=Box( + l=0.6876581404256177, + t=0.36273518298114954, + w=0.09146006356715199, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4419, + end=4427, + box=Box( + l=0.2686971421659869, + t=0.4479113936500019, + w=0.06846450520430858, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4497, + end=4505, + box=Box( + l=0.2686971421659869, + t=0.46920514936630686, + w=0.06846450520430858, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4517, + end=4520, + box=Box( + l=0.42195400318507725, + t=0.46920514936630686, + 
w=0.029000512031393755, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4574, + end=4581, + box=Box( + l=0.2686971421659869, + t=0.49049890508261185, + w=0.07810456460532592, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4582, + end=4587, + box=Box( + l=0.35061756361754887, + t=0.49049890508261185, + w=0.03904224057412029, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4588, + end=4591, + box=Box( + l=0.39347566103790516, + t=0.49049890508261185, + w=0.023417310961637943, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4592, + end=4601, + box=Box( + l=0.4207088288457791, + t=0.49049890508261185, + w=0.08254300862121101, + h=0.014200429984914883, + page=4, + ), + ), + Span( + start=4602, + end=4613, + box=Box( + l=0.5070676943132262, + t=0.49049890508261185, + w=0.09481400090042272, + h=0.014200429984914883, + page=4, + ), + ), +] + +list_of_spans_to_merge_2 = [ + Span( + start=30113, + end=30119, + box=Box( + l=0.12095229775767885, + t=0.3578497466414853, + w=0.05243790645011725, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30120, + end=30124, + box=Box( + l=0.17929474059091924, + t=0.3578497466414853, + w=0.030687522426571887, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30125, + end=30129, + box=Box( + l=0.21799556239458678, + t=0.3578497466414853, + w=0.04350076804709073, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30130, + end=30135, + box=Box( + l=0.26740086682480063, + t=0.3578497466414853, + w=0.050208642713631964, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30136, + end=30141, + box=Box( + l=0.32351404592155575, + t=0.3578497466414853, + w=0.0446254416438761, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30142, + end=30151, + box=Box( + l=0.37404402394855496, + t=0.3578497466414853, + w=0.0769598075514552, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30152, + end=30155, + box=Box( + l=0.4569284513402187, + t=0.3578497466414853, + w=0.029000512031393852, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30156, + end=30165, + box=Box( + l=0.4918334997547357, + t=0.3578497466414853, + w=0.0792091547450259, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30166, + end=30175, + box=Box( + l=0.5769471908828846, + t=0.3578497466414853, + w=0.07175819216632291, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30176, + end=30179, + box=Box( + l=0.6576023545380633, + t=0.3578497466414853, + w=0.03122977576787907, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30180, + end=30184, + box=Box( + l=0.6947366666890655, + t=0.3578497466414853, + w=0.03904224057412024, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30185, + end=30190, + box=Box( + l=0.7396834436463088, + t=0.3578497466414853, + w=0.05020864271363187, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30191, + end=30193, + box=Box( + l=0.7957966227430638, + t=0.3578497466414853, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30194, + end=30197, + box=Box( + l=0.12095229775767885, + t=0.37500875791374183, + w=0.024541984558423317, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30198, + end=30207, + box=Box( + l=0.1518205712980198, + t=0.37500875791374183, + w=0.07695980755145514, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30208, + end=30210, + box=Box( + l=0.2351066678313926, + t=0.37500875791374183, + 
w=0.013395665875996984, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30211, + end=30214, + box=Box( + l=0.2548286226893072, + t=0.37500875791374183, + w=0.02231272082193805, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30215, + end=30217, + box=Box( + l=0.283467632493163, + t=0.37500875791374183, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30218, + end=30221, + box=Box( + l=0.3054188510875629, + t=0.37500875791374183, + w=0.024541984558423317, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30222, + end=30229, + box=Box( + l=0.33628712462790383, + t=0.37500875791374183, + w=0.055570925755447906, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30230, + end=30235, + box=Box( + l=0.3981843393652693, + t=0.37500875791374183, + w=0.04183384110899822, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30236, + end=30240, + box=Box( + l=0.44668588822663785, + t=0.37500875791374183, + w=0.03570838669793504, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30241, + end=30244, + box=Box( + l=0.4887205639064905, + t=0.37500875791374183, + w=0.020083457085452783, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30245, + end=30255, + box=Box( + l=0.5151303099738609, + t=0.37500875791374183, + w=0.08810612623388145, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30256, + end=30259, + box=Box( + l=0.6095627251896601, + t=0.37500875791374183, + w=0.022312720821938, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30260, + end=30262, + box=Box( + l=0.6382017349935157, + t=0.37500875791374183, + w=0.015624929612482252, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30263, + end=30268, + box=Box( + l=0.6601529535879158, + t=0.37500875791374183, + w=0.03958449391542752, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30269, + end=30273, + box=Box( + l=0.7098795933314969, + t=0.37500875791374183, + w=0.035708386697935225, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30274, + end=30276, + box=Box( + l=0.7519142690113497, + t=0.37500875791374183, + w=0.013395665875997033, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30277, + end=30278, + box=Box( + l=0.7716362238692644, + t=0.37500875791374183, + w=0.008917054945941066, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30279, + end=30281, + box=Box( + l=0.7868795677971232, + t=0.37500875791374183, + w=0.02454198455842322, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30282, + end=30291, + box=Box( + l=0.12095229775767885, + t=0.3921677691859983, + w=0.08031374488472577, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30292, + end=30296, + box=Box( + l=0.2062869069137678, + t=0.3921677691859983, + w=0.03904224057412024, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30297, + end=30302, + box=Box( + l=0.25035001175925126, + t=0.3921677691859983, + w=0.050208642713631964, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30303, + end=30311, + box=Box( + l=0.30557951874424644, + t=0.3921677691859983, + w=0.08143841848151108, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30312, + end=30314, + box=Box( + l=0.3920388014971207, + t=0.3921677691859983, + w=0.016729519752182193, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30315, + end=30321, + box=Box( + l=0.4137891855206661, + t=0.3921677691859983, + 
w=0.0535625800469026, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30322, + end=30328, + box=Box( + l=0.47237262983893197, + t=0.3921677691859983, + w=0.05354249658981717, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30329, + end=30333, + box=Box( + l=0.5309359907001122, + t=0.3921677691859983, + w=0.03681297683763493, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30334, + end=30336, + box=Box( + l=0.5727698318091105, + t=0.3921677691859983, + w=0.01672951975218224, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30337, + end=30344, + box=Box( + l=0.5945202158326559, + t=0.3921677691859983, + w=0.060230287799273016, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30345, + end=30348, + box=Box( + l=0.6597713679032922, + t=0.3921677691859983, + w=0.029000512031393946, + h=0.014200429984914883, + page=19, + ), + ), + Span( + start=30349, + end=30359, + box=Box( + l=0.6937927442060494, + t=0.3921677691859983, + w=0.07834556609035141, + h=0.014200429984914883, + page=19, + ), + ), +] def test_merge_spans(): - assert len(list_of_spans_to_merge) == (len(MergeSpans(list_of_spans_to_merge, 0, 0) - .merge_neighbor_spans_by_box_coordinate())) + assert len(list_of_spans_to_merge) == ( + len( + MergeSpans( + list_of_spans_to_merge, 0, 0 + ).merge_neighbor_spans_by_box_coordinate() + ) + ) - assert 4 == len(MergeSpans(list_of_spans_to_merge, 0.04387334, 0.01421097).merge_neighbor_spans_by_box_coordinate()) + assert 4 == len( + MergeSpans( + list_of_spans_to_merge, 0.04387334, 0.01421097 + ).merge_neighbor_spans_by_box_coordinate() + ) merge_spans = MergeSpans(list_of_spans_to_merge_2, 0.04387334, 0.01421097) assert 1 == len(merge_spans.merge_neighbor_spans_by_box_coordinate()) - assert [30113, 30359] == [merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end] + assert [30113, 30359] == [ + merge_spans.merge_neighbor_spans_by_box_coordinate()[0].start, + merge_spans.merge_neighbor_spans_by_box_coordinate()[0].end, + ] def test_merge_neighbor_spans_by_symbol_distance(): - assert 7 == (len(MergeSpans(list_of_spans_to_merge, index_distance=10) - .merge_neighbor_spans_by_symbol_distance())) - + assert 7 == ( + len( + MergeSpans( + list_of_spans_to_merge, index_distance=10 + ).merge_neighbor_spans_by_symbol_distance() + ) + ) - assert 10 == len(MergeSpans(list_of_spans_to_merge, index_distance=1).merge_neighbor_spans_by_symbol_distance()) + assert 10 == len( + MergeSpans( + list_of_spans_to_merge, index_distance=1 + ).merge_neighbor_spans_by_symbol_distance() + ) list_of_spans_to_merge_2 = [ Span(start=1, end=3, box=Box(l=0.1, t=0.2, w=0.2, h=0.2, page=11)), @@ -360,7 +857,9 @@ def test_merge_neighbor_spans_by_symbol_distance(): assert 1 == len(result) assert set([(1, 7)]) == set([(entry.start, entry.end) for entry in result]) - assert [Box(l=0.1, t=0.2, w=0.4, h=0.2, page=11)] == [entry.box for entry in result] + assert [Box(l=0.1, t=0.2, w=0.4, h=0.2, page=11).coordinates] == [ + entry.box.coordinates for entry in result + ] def test_from_span_groups_with_box_groups(): @@ -370,30 +869,41 @@ def test_from_span_groups_with_box_groups(): list_of_spans_to_merge_in_span_group_format.append( SpanGroup( spans=[Span(start=span.start, end=span.end)], - box_group=BoxGroup(boxes=[span.box]) + box_group=BoxGroup(boxes=[span.box]), ) ) - assert 7 == (len(MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, - 
index_distance=10).merge_neighbor_spans_by_symbol_distance()) - ) + assert 7 == ( + len( + MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, index_distance=10 + ).merge_neighbor_spans_by_symbol_distance() + ) + ) - assert len(list_of_spans_to_merge) == (len(MergeSpans.from_span_groups_with_box_groups( - list_of_spans_to_merge_in_span_group_format, - 0, - 0).merge_neighbor_spans_by_box_coordinate())) + assert len(list_of_spans_to_merge) == ( + len( + MergeSpans.from_span_groups_with_box_groups( + list_of_spans_to_merge_in_span_group_format, 0, 0 + ).merge_neighbor_spans_by_box_coordinate() + ) + ) def test_box_groups_to_span_groups(): # basic doc annotated with pages and tokens, from pdfplumber parser split at punctuation - with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", "r") as f: + with open( + fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__pdfplumber_doc.json", + "r", + ) as f: raw_json = f.read() fixture_doc_json = json.loads(raw_json) doc = Document.from_json(fixture_doc_json) # boxes drawn neatly around bib entries - with open(fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r") as f: + with open( + fixture_path / "20fdafb68d0e69d193527a9a1cbe64e7e69a3798__bib_entries.json", "r" + ) as f: raw_json = f.read() fixture_bib_entries_json = json.loads(raw_json)["bib_entries"] @@ -404,15 +914,28 @@ def test_box_groups_to_span_groups(): # generate span_groups with different settings overlap_span_groups = box_groups_to_span_groups(box_groups, doc, center=False) - overlap_at_token_center_span_groups = box_groups_to_span_groups(box_groups, doc, center=True) - overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups(box_groups, doc, center=True, pad_x=True) - - assert (len(box_groups) == len(overlap_span_groups) == len(overlap_at_token_center_span_groups) == len(overlap_at_token_center_span_groups_x_padded)) + overlap_at_token_center_span_groups = box_groups_to_span_groups( + box_groups, doc, center=True + ) + overlap_at_token_center_span_groups_x_padded = box_groups_to_span_groups( + box_groups, doc, center=True, pad_x=True + ) + + assert ( + len(box_groups) + == len(overlap_span_groups) + == len(overlap_at_token_center_span_groups) + == len(overlap_at_token_center_span_groups_x_padded) + ) # annotate all onto doc to extract texts: doc.annotate(overlap_span_groups=overlap_span_groups) - doc.annotate(overlap_at_token_center_span_groups=overlap_at_token_center_span_groups) - doc.annotate(overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded) + doc.annotate( + overlap_at_token_center_span_groups=overlap_at_token_center_span_groups + ) + doc.annotate( + overlap_at_token_center_span_groups_x_padded=overlap_at_token_center_span_groups_x_padded + ) # when center=False, any token overlap with BoxGroup becomes part of the SpanGroup # in this example, tokens from bib entry '29 and '31' overlap with the box drawn neatly around '30' @@ -444,4 +967,6 @@ def test_box_groups_to_span_groups(): assert doc.overlap_at_token_center_span_groups_x_padded[6].text.startswith("[6]") # original box_group boxes are saved - assert all([sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups]) + assert all( + [sg.box_group is not None for sg in doc.overlap_at_token_center_span_groups] + )
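The short sketch below is an editor-added illustration, not part of the patch: it shows the Span / MergeSpans helpers that the new tests exercise, assuming the mmda package from this repo is importable and behaves as the assertions above describe (Span.cluster_spans, Span.small_spans_to_big_span(..., merge_boxes=False), MergeSpans(...).merge_neighbor_spans_by_symbol_distance(), and Box.coordinates). The spans, boxes, and printed values are illustrative only.

from mmda.types.box import Box
from mmda.types.span import Span
from mmda.utils.tools import MergeSpans

# Three spans on the same page; the last one is far away in symbol distance.
spans = [
    Span(start=0, end=8, box=Box(l=0.10, t=0.20, w=0.20, h=0.02, page=0)),
    Span(start=8, end=16, box=Box(l=0.30, t=0.20, w=0.20, h=0.02, page=0)),
    Span(start=30, end=40, box=Box(l=0.10, t=0.40, w=0.20, h=0.02, page=0)),
]

# Non-overlapping spans each form their own cluster of indices into `spans`,
# mirroring test_cluster_spans above.
print(Span.cluster_spans(spans=spans))  # expected: [[0], [1], [2]]

# Collapse contiguous spans into one big span; merge_boxes=False skips box
# merging, as in the new test_small_spans_to_big_span test above.
big = Span.small_spans_to_big_span(spans=spans[:2], merge_boxes=False)
print(big.start, big.end)  # expected: 0 16

# MergeSpans joins neighbors whose gap in symbol index is within
# index_distance; spans 0 and 1 merge, span 2 stays separate.
merged = MergeSpans(
    list_of_spans=spans, index_distance=1
).merge_neighbor_spans_by_symbol_distance()
print([(s.start, s.end) for s in merged])  # expected: [(0, 16), (30, 40)]

# As in the updated assertions, compare boxes via .coordinates rather than
# Box equality; the merged span's box should roughly cover both source boxes.
if merged and merged[0].box is not None:
    print(merged[0].box.coordinates)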