diff --git a/mmda/parsers/pdfplumber_parser.py b/mmda/parsers/pdfplumber_parser.py index 06d1f1ae..27c6312e 100644 --- a/mmda/parsers/pdfplumber_parser.py +++ b/mmda/parsers/pdfplumber_parser.py @@ -365,39 +365,3 @@ def _align_coarse_and_fine_tokens( return out - - - - - - - - -""" - - - - - - row_annos.append(row) - current_rows_tokens = [] - - # if new row... is it also a new page? - if next_page_id == current_page_id: - current_pages_tokens.append(token) - else: - page = SpanGroup( - spans=[ - Span( - start=current_pages_tokens[0][0].start, - end=current_pages_tokens[-1][0].end, - box=Box.small_boxes_to_big_box( - boxes=[span.box for t in current_pages_tokens for span in t] - ), - ) - ] - ) - page_annos.append(page) - current_pages_tokens = [] - -""" \ No newline at end of file diff --git a/mmda/types/__init__.py b/mmda/types/__init__.py index d0f3929c..24dcf0aa 100644 --- a/mmda/types/__init__.py +++ b/mmda/types/__init__.py @@ -1,5 +1,5 @@ from mmda.types.document import Document -from mmda.types.annotation import SpanGroup, BoxGroup +from mmda.types.annotation import SpanGroup, BoxGroup, Relation from mmda.types.span import Span from mmda.types.box import Box from mmda.types.image import PILImage @@ -12,5 +12,6 @@ 'Span', 'Box', 'PILImage', - 'Metadata' + 'Metadata', + "Relation" ] \ No newline at end of file diff --git a/mmda/types/annotation.py b/mmda/types/annotation.py index 4857df5c..5d925d83 100644 --- a/mmda/types/annotation.py +++ b/mmda/types/annotation.py @@ -5,7 +5,10 @@ Collections of Annotations are how one constructs a new Iterable of Group-type objects within the Document +@kylel, @lucas + """ +import logging import warnings from abc import abstractmethod from copy import deepcopy @@ -18,11 +21,9 @@ if TYPE_CHECKING: from mmda.types.document import Document - __all__ = ["Annotation", "BoxGroup", "SpanGroup", "Relation"] - def warn_deepcopy_of_annotation(obj: "Annotation") -> None: """Warns when a deepcopy is performed on an Annotation.""" @@ -34,6 +35,22 @@ def warn_deepcopy_of_annotation(obj: "Annotation") -> None: warnings.warn(msg, UserWarning, stacklevel=2) +class AnnotationName: + """Stores a name that uniquely identifies this Annotation within a Document""" + + def __init__(self, field: str, id: int): + self.field = field + self.id = id + + def __str__(self) -> str: + return f"{self.field}-{self.id}" + + @classmethod + def from_str(cls, s: str) -> 'AnnotationName': + field, id = s.split('-') + id = int(id) + return AnnotationName(field=field, id=id) + class Annotation: """Annotation is intended for storing model predictions for a document.""" @@ -42,40 +59,55 @@ def __init__( self, id: Optional[int] = None, doc: Optional['Document'] = None, + field: Optional[str] = None, metadata: Optional[Metadata] = None ): self.id = id self.doc = doc + self.field = field self.metadata = metadata if metadata else Metadata() @abstractmethod def to_json(self) -> Dict: - pass + raise NotImplementedError @classmethod @abstractmethod def from_json(cls, annotation_dict: Dict) -> "Annotation": - pass + raise NotImplementedError + + @property + def name(self) -> Optional[AnnotationName]: + if self.field and self.id: + return AnnotationName(field=self.field, id=self.id) + else: + return None - def attach_doc(self, doc: "Document") -> None: + def _attach_doc(self, doc: "Document", field: str) -> None: if not self.doc: self.doc = doc + self.field = field else: raise AttributeError("This annotation already has an attached document") - # TODO[kylel] - comment explaining - def __getattr__(self, field: str) -> List["Annotation"]: - if self.doc is None: - raise ValueError("This annotation is not attached to a document") + def _get_siblings(self) -> List['Annotation']: + """This method gets all other objects sharing the same field as the current object. + Only works after a Document has been attached, which is how objects learn their `field`.""" + if not self.doc: + raise AttributeError("This annotation does not have an attached document") + return self.doc.__getattr__(self.field) - if field in self.doc.fields: - return self.doc.find_overlapping(self, field) + def __getattr__(self, field: str) -> List["Annotation"]: + """This method allows jumping from an object of one field to all overlapping + objects of another field. For example `page.tokens` jumps from a particular page + to all its intersecting tokens.""" + if not self.doc: + raise AttributeError("This annotation does not have an attached document") if field in self.doc.fields: return self.doc.find_overlapping(self, field) - - return self.__getattribute__(field) - + else: + return [] class BoxGroup(Annotation): @@ -84,12 +116,14 @@ def __init__( boxes: List[Box], id: Optional[int] = None, doc: Optional['Document'] = None, + field: Optional[str] = None, metadata: Optional[Metadata] = None, ): self.boxes = boxes - super().__init__(id=id, doc=doc, metadata=metadata) + super().__init__(id=id, doc=doc, field=field, metadata=metadata) def to_json(self) -> Dict: + """Note: even if `doc` or `field` are attached, don't include in JSON to avoid bloat""" box_group_dict = dict( boxes=[box.to_json() for box in self.boxes], id=self.id, @@ -132,6 +166,7 @@ def __deepcopy__(self, memo): box_group = BoxGroup( boxes=deepcopy(self.boxes, memo), id=self.id, + field=self.field, metadata=deepcopy(self.metadata, memo) ) @@ -142,47 +177,38 @@ def __deepcopy__(self, memo): @property def type(self) -> str: + logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') return self.metadata.get("type", None) @type.setter def type(self, type: Union[str, None]) -> None: + logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') self.metadata.type = type class SpanGroup(Annotation): - def __init__( self, spans: List[Span], box_group: Optional[BoxGroup] = None, id: Optional[int] = None, doc: Optional['Document'] = None, + field: Optional[str] = None, metadata: Optional[Metadata] = None, ): self.spans = spans self.box_group = box_group - super().__init__(id=id, doc=doc, metadata=metadata) + super().__init__(id=id, doc=doc, field=field, metadata=metadata) @property def symbols(self) -> List[str]: if self.doc is not None: - return [ - self.doc.symbols[span.start: span.end] for span in self.spans - ] + return [self.doc.symbols[span.start: span.end] for span in self.spans] else: return [] - def annotate( - self, is_overwrite: bool = False, **kwargs: Iterable["Annotation"] - ) -> None: - if self.doc is None: - raise ValueError("SpanGroup has no attached document!") - - key_remaps = {k: v for k, v in kwargs.items()} - - self.doc.annotate(is_overwrite=is_overwrite, **key_remaps) - def to_json(self) -> Dict: + """Note: even if `doc` or `field` are attached, don't include in JSON to avoid bloat""" span_group_dict = dict( spans=[span.to_json() for span in self.spans], id=self.id, @@ -208,7 +234,7 @@ def from_json(cls, span_group_dict: Dict) -> "SpanGroup": else: # this fallback is necessary to ensure compatibility with span # groups that were create before the metadata migration and - # therefore have "id", "type" in the root of the json dict instead. + # therefore have "type" in the root of the json dict instead. metadata_dict = { "type": span_group_dict.get("type", None), "text": span_group_dict.get("text", None) @@ -255,6 +281,7 @@ def __deepcopy__(self, memo): span_group = SpanGroup( spans=deepcopy(self.spans, memo), id=self.id, + field=self.field, metadata=deepcopy(self.metadata, memo), box_group=deepcopy(self.box_group, memo) ) @@ -266,10 +293,12 @@ def __deepcopy__(self, memo): @property def type(self) -> str: + logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') return self.metadata.get("type", None) @type.setter def type(self, type: Union[str, None]) -> None: + logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') self.metadata.type = type @property @@ -284,6 +313,49 @@ def text(self, text: Union[str, None]) -> None: self.metadata.text = text - class Relation(Annotation): - pass \ No newline at end of file + def __init__( + self, + key: SpanGroup, + value: SpanGroup, + id: Optional[int] = None, + doc: Optional['Document'] = None, + field: Optional[str] = None, + metadata: Optional[Metadata] = None + ): + if key.name is None: + raise ValueError(f'Relation requires the key {key} to have a `.name`') + if value.name is None: + raise ValueError(f'Relation requires the value {value} to have a `.name`') + self.key = key + self.value = value + super().__init__(id=id, doc=doc, field=field, metadata=metadata) + + def to_json(self) -> Dict: + """Note: even if `doc` or `field` are attached, don't include in JSON to avoid bloat""" + relation_dict = dict( + key=str(self.key.name), + value=str(self.value.name), + id=self.id, + metadata=self.metadata.to_json() + ) + return { + key: value + for key, value in relation_dict.items() + if value is not None + } # only serialize non-null values + + @classmethod + def from_json( + cls, + relation_dict: Dict, + doc: 'Document', + ) -> "Relation": + key_name = AnnotationName.from_str(s=relation_dict['key']) + value_name = AnnotationName.from_str(s=relation_dict['value']) + return cls( + key=doc.locate_annotation(name=key_name), + value=doc.locate_annotation(name=value_name), + id=relation_dict.get("id", None), + metadata=Metadata.from_json(relation_dict.get('metadata', {})) + ) diff --git a/mmda/types/document.py b/mmda/types/document.py index cbd00655..b0d6f3da 100644 --- a/mmda/types/document.py +++ b/mmda/types/document.py @@ -9,7 +9,7 @@ from copy import deepcopy from typing import Dict, Iterable, List, Optional -from mmda.types.annotation import Annotation, BoxGroup, SpanGroup +from mmda.types.annotation import Annotation, BoxGroup, SpanGroup, AnnotationName from mmda.types.image import PILImage from mmda.types.indexers import Indexer, SpanGroupIndexer from mmda.types.names import Images, Symbols @@ -17,7 +17,6 @@ class Document: - SPECIAL_FIELDS = [Symbols, Images] UNALLOWED_FIELD_NAMES = ["fields"] @@ -32,35 +31,35 @@ def fields(self) -> List[str]: return self.__fields # TODO: extend implementation to support DocBoxGroup - def find_overlapping(self, query: Annotation, field_name: str) -> List[Annotation]: + def find_overlapping(self, query: Annotation, field: str) -> List[Annotation]: if not isinstance(query, SpanGroup): raise NotImplementedError( f"Currently only supports query of type SpanGroup" ) - return self.__indexers[field_name].find(query=query) + return self.__indexers[field].find(query=query) def annotate( - self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] + self, is_overwrite: bool = False, **kwargs: Iterable[Annotation] ) -> None: """Annotate the fields for document symbols (correlating the annotations with the symbols) and store them into the papers. """ # 1) check validity of field names - for field_name in kwargs.keys(): + for field in kwargs.keys(): assert ( - field_name not in self.SPECIAL_FIELDS - ), f"The field_name {field_name} should not be in {self.SPECIAL_FIELDS}." + field not in self.SPECIAL_FIELDS + ), f"The field {field} should not be in {self.SPECIAL_FIELDS}." - if field_name in self.fields: + if field in self.fields: # already existing field, check if ok overriding if not is_overwrite: raise AssertionError( - f"This field name {field_name} already exists. To override, set `is_overwrite=True`" + f"This field name {field} already exists. To override, set `is_overwrite=True`" ) - elif field_name in dir(self): + elif field in dir(self): # not an existing field, but a reserved class method name raise AssertionError( - f"The field_name {field_name} should not conflict with existing class properties" + f"The field {field} should not conflict with existing class properties" ) # Kyle's preserved comment: @@ -68,39 +67,39 @@ def annotate( # overhead on large documents. # 2) register fields into Document - for field_name, annotations in kwargs.items(): + for field, annotations in kwargs.items(): if len(annotations) == 0: - warnings.warn(f"The annotations is empty for the field {field_name}") - setattr(self, field_name, []) - self.__fields.append(field_name) + warnings.warn(f"The annotations is empty for the field {field}") + setattr(self, field, []) + self.__fields.append(field) continue annotation_types = {type(a) for a in annotations} assert ( len(annotation_types) == 1 - ), f"Annotations in field_name {field_name} more than 1 type: {annotation_types}" + ), f"Annotations in field {field} more than 1 type: {annotation_types}" annotation_type = annotation_types.pop() if annotation_type == SpanGroup: span_groups = self._annotate_span_group( - span_groups=annotations, field_name=field_name + span_groups=annotations, field=field ) elif annotation_type == BoxGroup: # TODO: not good. BoxGroups should be stored on their own, not auto-generating SpanGroups. span_groups = self._annotate_box_group( - box_groups=annotations, field_name=field_name + box_groups=annotations, field=field ) else: raise NotImplementedError( - f"Unsupported annotation type {annotation_type} for {field_name}" + f"Unsupported annotation type {annotation_type} for {field}" ) # register fields - setattr(self, field_name, span_groups) - self.__fields.append(field_name) + setattr(self, field, span_groups) + self.__fields.append(field) def annotate_images( - self, images: Iterable[PILImage], is_overwrite: bool = False + self, images: Iterable[PILImage], is_overwrite: bool = False ) -> None: if not is_overwrite and len(self.images) > 0: raise AssertionError( @@ -122,7 +121,7 @@ def annotate_images( self.images = images def _annotate_span_group( - self, span_groups: List[SpanGroup], field_name: str + self, span_groups: List[SpanGroup], field: str ) -> List[SpanGroup]: """Annotate the Document using a bunch of span groups. It will associate the annotations with the document symbols. @@ -131,15 +130,15 @@ def _annotate_span_group( # 1) add Document to each SpanGroup for span_group in span_groups: - span_group.attach_doc(doc=self) + span_group._attach_doc(doc=self, field=field) # 2) Build fast overlap lookup index - self.__indexers[field_name] = SpanGroupIndexer(span_groups) + self.__indexers[field] = SpanGroupIndexer(span_groups) return span_groups def _annotate_box_group( - self, box_groups: List[BoxGroup], field_name: str + self, box_groups: List[BoxGroup], field: str ) -> List[SpanGroup]: """Annotate the Document using a bunch of box groups. It will associate the annotations with the document symbols. @@ -177,7 +176,7 @@ def _annotate_box_group( derived_span_groups.append( SpanGroup( spans=MergeSpans(list_of_spans=all_token_spans_with_box_group, index_distance=1) - .merge_neighbor_spans_by_symbol_distance(), box_group=box_group, + .merge_neighbor_spans_by_symbol_distance(), box_group=box_group, # id = box_id, ) # TODO Right now we cannot assign the box id, or otherwise running doc.blocks will @@ -195,7 +194,7 @@ def _annotate_box_group( span_group.id = box_id return self._annotate_span_group( - span_groups=derived_span_groups, field_name=field_name + span_groups=derived_span_groups, field=field ) # @@ -245,16 +244,23 @@ def from_json(cls, doc_dict: Dict) -> "Document": ) # 2) convert span group dicts to span gropus - field_name_to_span_groups = {} - for field_name, span_group_dicts in doc_dict.items(): - if field_name not in doc.SPECIAL_FIELDS: + field_to_span_groups = {} + for field, span_group_dicts in doc_dict.items(): + if field not in doc.SPECIAL_FIELDS: span_groups = [ SpanGroup.from_json(span_group_dict=span_group_dict) for span_group_dict in span_group_dicts ] - field_name_to_span_groups[field_name] = span_groups + field_to_span_groups[field] = span_groups # 3) load annotations for each field - doc.annotate(**field_name_to_span_groups) + doc.annotate(**field_to_span_groups) return doc + + def locate_annotation(self, name: AnnotationName) -> Annotation: + candidates = self.__getattribute__(name.field) + matched_annotations = [c for c in candidates if c.id == name.id] + assert len(matched_annotations) <= 1, \ + f"Multiple annotations in field {name.field} with same ID {name.id}" + return matched_annotations[0] diff --git a/mmda/types/names.py b/mmda/types/names.py index 49460dbe..020debb2 100644 --- a/mmda/types/names.py +++ b/mmda/types/names.py @@ -1,6 +1,6 @@ """ -Names of fields, as strings +Names of Annotations, as strings @kylel @@ -11,8 +11,26 @@ Images = 'images' Pages = 'pages' -Tokens = 'tokens' Rows = 'rows' -Sentences = 'sents' Blocks = 'blocks' + +Tokens = 'tokens' Words = 'words' +Sentences = 'sents' +Sections = 'secs' +Paragraphs = 'paras' + +Figures = 'figures' +Tables = 'tables' +Captions = 'captions' + +BibEntries = 'bibs' +CiteMentions = 'cites' +ReferenceMentions = 'refs' + +# singletons +Title = 'title' +Abstract = 'abstract' + +# relations +RefersTo = 'refers_to' \ No newline at end of file diff --git a/tests/test_internal_ai2/test_api.py b/tests/test_internal_ai2/test_api.py index 404f2940..811f8f8c 100644 --- a/tests/test_internal_ai2/test_api.py +++ b/tests/test_internal_ai2/test_api.py @@ -121,3 +121,4 @@ def test_span_group(self): mmda_api.SpanGroup.from_mmda(span_group.to_mmda()), span_group ) + diff --git a/tests/test_types/test_json_conversion.py b/tests/test_types/test_json_conversion.py index e7a5f27d..6cd3e608 100644 --- a/tests/test_types/test_json_conversion.py +++ b/tests/test_types/test_json_conversion.py @@ -8,25 +8,50 @@ import json from pathlib import Path -from mmda.types import BoxGroup, SpanGroup, Document, Metadata +from mmda.types import BoxGroup, SpanGroup, Document, Metadata, Relation from mmda.parsers import PDFPlumberParser - PDFFILEPATH = Path(__file__).parent / "../fixtures/1903.10676.pdf" def test_span_group_conversion(): - sg = SpanGroup(spans=[], id=3, metadata=Metadata.from_json({"text": "test"})) + sg = SpanGroup(spans=[], id=3, metadata=Metadata(text='test')) sg2 = SpanGroup.from_json(sg.to_json()) assert sg2.to_json() == sg.to_json() assert sg2.__dict__ == sg.__dict__ - bg = BoxGroup(boxes=[], metadata=Metadata.from_json({"text": "test", "id": 1})) + bg = BoxGroup(boxes=[], metadata=Metadata(text='test')) bg2 = BoxGroup.from_json(bg.to_json()) assert bg2.to_json() == bg.to_json() assert bg2.__dict__ == bg.__dict__ +def test_relation_conversion(): + r = Relation( + key=SpanGroup(spans=[], id=3, metadata=Metadata(foobar='test'), field='abc'), + value=SpanGroup(spans=[], id=5, metadata=Metadata(foobar='test'), field='xyz'), + id=999, + metadata=Metadata(blabla='something') + ) + + # to & from JSON + r_dict_minimal = { + 'key': 'abc-3', + 'value': 'xyz-5', + 'id': 999, + 'metadata': {'blabla': 'something'} + } + assert r.to_json() == r_dict_minimal + + doc = Document.from_json(doc_dict={ + 'symbols': 'asdfasdf', + 'abc': [{'spans': [], 'id': 3, 'metadata': {'foobar': 'test'}}], + 'xyz': [{'spans': [], 'id': 5, 'metadata': {'foobar': 'test'}}] + }) + assert r_dict_minimal == r.from_json(r_dict_minimal, doc=doc).to_json() + + + def test_doc_conversion(): pdfparser = PDFPlumberParser()