diff --git a/examples/bibliography_extraction/main.py b/examples/bibliography_extraction/main.py index c71c5b9b..53eeddc3 100644 --- a/examples/bibliography_extraction/main.py +++ b/examples/bibliography_extraction/main.py @@ -10,14 +10,14 @@ from mmda.predictors.hf_predictors.vila_predictor import HVILAPredictor from mmda.predictors.tesseract_predictors import TesseractBlockPredictor from mmda.rasterizers.rasterizer import PDF2ImageRasterizer -from mmda.types.annotation import BoxGroup, SpanGroup +from mmda.types.annotation import BoxGroup, Entity from mmda.types.document import Document PDF_PATH = "resources/maml.pdf" -def _clone_span_group(span_group: SpanGroup): - return SpanGroup( +def _clone_span_group(span_group: Entity): + return Entity( spans=span_group.spans, id=span_group.id, text=span_group.text, @@ -46,7 +46,7 @@ def _index_document_pages(document) -> List[PageSpan]: return page_spans -def _find_page_num(span_group: SpanGroup, page_spans: List[PageSpan]) -> int: +def _find_page_num(span_group: Entity, page_spans: List[PageSpan]) -> int: s = min([span.start for span in span_group.spans]) e = max([span.end for span in span_group.spans]) @@ -58,7 +58,7 @@ def _find_page_num(span_group: SpanGroup, page_spans: List[PageSpan]) -> int: def _highest_overlap_block( - token: SpanGroup, blocks: Iterable[BoxGroup] + token: Entity, blocks: Iterable[BoxGroup] ) -> Optional[BoxGroup]: assert len(token.spans) == 1 token_box = token.spans[0].box @@ -78,7 +78,7 @@ def _highest_overlap_block( return found_block -def extract_bibliography_grotoap2(document: Document) -> Iterable[SpanGroup]: +def extract_bibliography_grotoap2(document: Document) -> Iterable[Entity]: """GROTOAP2 has type 1 for REFERENCES""" return [_clone_span_group(sg) for sg in document.preds if sg.type == 1] diff --git a/examples/vila_for_scidoc_parsing/main.py b/examples/vila_for_scidoc_parsing/main.py index c40d3a00..3464bd4f 100644 --- a/examples/vila_for_scidoc_parsing/main.py +++ b/examples/vila_for_scidoc_parsing/main.py @@ -8,7 +8,7 @@ from mmda.parsers.pdfplumber_parser import PDFPlumberParser from mmda.rasterizers.rasterizer import PDF2ImageRasterizer -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from mmda.predictors.lp_predictors import LayoutParserPredictor from mmda.predictors.hf_predictors.vila_predictor import IVILAPredictor, HVILAPredictor @@ -33,7 +33,7 @@ def draw_tokens( image, - doc_tokens: List[SpanGroup], + doc_tokens: List[Entity], color_map=None, token_boundary_width=0, alpha=0.25, @@ -65,7 +65,7 @@ def draw_tokens( def draw_blocks( image, - doc_tokens: List[SpanGroup], + doc_tokens: List[Entity], color_map=None, token_boundary_width=0, alpha=0.25, diff --git a/examples/vlue_evaluation/main.py b/examples/vlue_evaluation/main.py index 1d76762a..efa671ab 100644 --- a/examples/vlue_evaluation/main.py +++ b/examples/vlue_evaluation/main.py @@ -17,7 +17,7 @@ IVILAPredictor) from mmda.predictors.lp_predictors import LayoutParserPredictor from mmda.rasterizers.rasterizer import PDF2ImageRasterizer -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from mmda.types.document import Document @@ -31,7 +31,7 @@ class VluePrediction: def _vila_docbank_extract_entities(types: List[str]): - def extractor(doc: Document) -> Dict[str, List[SpanGroup]]: + def extractor(doc: Document) -> Dict[str, List[Entity]]: mapping = { "paragraph": 0, "title": 1, @@ -62,7 +62,7 @@ def extractor(doc: Document) -> Dict[str, List[SpanGroup]]: def 
_vila_grotoap2_extract_entities(types: List[str]): - def extractor(doc: Document) -> Dict[str, List[SpanGroup]]: + def extractor(doc: Document) -> Dict[str, List[Entity]]: # TODO: Have some sort of unified mapping between this and docbank # TODO: Below title and abstract have been lower-cased to match docbank mapping = { @@ -107,7 +107,7 @@ def vila_prediction( id_: str, doc: Document, vila_predictor: BaseVILAPredictor, # pylint: disable=redefined-outer-name - vila_extractor: Callable[[Document], Dict[str, List[SpanGroup]]], + vila_extractor: Callable[[Document], Dict[str, List[Entity]]], ) -> VluePrediction: # Predict token types span_groups = vila_predictor.predict(doc) diff --git a/src/ai2_internal/api.py b/src/ai2_internal/api.py index 8f2f3cd9..b0c36b73 100644 --- a/src/ai2_internal/api.py +++ b/src/ai2_internal/api.py @@ -115,7 +115,7 @@ class SpanGroup(Annotation): text: Optional[str] @classmethod - def from_mmda(cls, span_group: mmda_ann.SpanGroup) -> "SpanGroup": + def from_mmda(cls, span_group: mmda_ann.Entity) -> "SpanGroup": box_group = ( BoxGroup.from_mmda(span_group.box_group) if span_group.box_group is not None @@ -136,13 +136,13 @@ def from_mmda(cls, span_group: mmda_ann.SpanGroup) -> "SpanGroup": ret.box_group = BoxGroup.from_mmda(span_group.box_group) return ret - def to_mmda(self) -> mmda_ann.SpanGroup: + def to_mmda(self) -> mmda_ann.Entity: metadata = mmda_ann.Metadata.from_json(self.attributes.dict()) if self.type: metadata.type = self.type if self.text: metadata.text = self.text - return mmda_ann.SpanGroup( + return mmda_ann.Entity( metadata=metadata, spans=[span.to_mmda() for span in self.spans], box_group=self.box_group.to_mmda()if self.box_group else None, diff --git a/src/ai2_internal/vila/interface.py b/src/ai2_internal/vila/interface.py index 688d27ae..a0579601 100644 --- a/src/ai2_internal/vila/interface.py +++ b/src/ai2_internal/vila/interface.py @@ -16,7 +16,7 @@ from mmda.predictors.hf_predictors.token_classification_predictor import ( IVILATokenClassificationPredictor, ) -from mmda.types.document import Document, SpanGroup +from mmda.types.document import Document, Entity from mmda.types.image import frombase64 logger = logging.getLogger(__name__) @@ -54,7 +54,7 @@ class Prediction(BaseModel): groups: List[api.SpanGroup] @classmethod - def from_mmda(cls, groups: List[SpanGroup]) -> "Prediction": + def from_mmda(cls, groups: List[Entity]) -> "Prediction": return cls(groups=[api.SpanGroup.from_mmda(grp) for grp in groups]) diff --git a/src/mmda/featurizers/citation_link_featurizers.py b/src/mmda/featurizers/citation_link_featurizers.py index fe7c4919..4f44ff97 100644 --- a/src/mmda/featurizers/citation_link_featurizers.py +++ b/src/mmda/featurizers/citation_link_featurizers.py @@ -4,7 +4,7 @@ from thefuzz import fuzz from typing import List, Tuple, Dict -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity DIGITS = re.compile(r'[0-9]+') @@ -24,7 +24,7 @@ MATCH_FIRST_TOKEN = "match_first_token" class CitationLink: - def __init__(self, mention: SpanGroup, bib: SpanGroup): + def __init__(self, mention: Entity, bib: Entity): self.mention = mention self.bib = bib diff --git a/src/mmda/parsers/grobid_parser.py b/src/mmda/parsers/grobid_parser.py index 8b028c34..1595dc8d 100644 --- a/src/mmda/parsers/grobid_parser.py +++ b/src/mmda/parsers/grobid_parser.py @@ -13,7 +13,7 @@ import json from mmda.parsers.parser import Parser -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from 
mmda.types.document import Document from mmda.types.metadata import Metadata from mmda.types.span import Span @@ -22,8 +22,8 @@ NS = {"tei": "http://www.tei-c.org/ns/1.0"} -def _null_span_group() -> SpanGroup: - sg = SpanGroup(spans=[]) +def _null_span_group() -> Entity: + sg = Entity(spans=[]) return sg @@ -97,7 +97,7 @@ def _parse_xml_to_doc(self, xml: str) -> Document: return document - def _get_title(self, root: et.Element) -> SpanGroup: + def _get_title(self, root: et.Element) -> Entity: matches = root.findall(".//tei:titleStmt/tei:title", NS) if len(matches) == 0: @@ -110,10 +110,10 @@ def _get_title(self, root: et.Element) -> SpanGroup: tokens = text.split() spans = _get_token_spans(text, tokens) - sg = SpanGroup(spans=spans, metadata=Metadata(text=text)) + sg = Entity(spans=spans, metadata=Metadata(text=text)) return sg - def _get_abstract(self, root: et.Element, offset: int) -> SpanGroup: + def _get_abstract(self, root: et.Element, offset: int) -> Entity: matches = root.findall(".//tei:profileDesc//tei:abstract//", NS) if len(matches) == 0: @@ -124,5 +124,5 @@ def _get_abstract(self, root: et.Element, offset: int) -> SpanGroup: tokens = text.split() spans = _get_token_spans(text, tokens, offset=offset) - sg = SpanGroup(spans=spans, metadata=Metadata(text=text)) + sg = Entity(spans=spans, metadata=Metadata(text=text)) return sg diff --git a/src/mmda/parsers/pdfplumber_parser.py b/src/mmda/parsers/pdfplumber_parser.py index cdd2cbfb..fe31f29c 100644 --- a/src/mmda/parsers/pdfplumber_parser.py +++ b/src/mmda/parsers/pdfplumber_parser.py @@ -6,7 +6,7 @@ import itertools from mmda.types.span import Span from mmda.types.box import Box -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from mmda.types.document import Document from mmda.parsers.parser import Parser from mmda.types.names import PagesField, RowsField, SymbolsField, TokensField @@ -195,7 +195,7 @@ def _convert_nested_text_to_doc_json( # 1) build tokens & symbols symbols = "" - token_annos: List[SpanGroup] = [] + token_annos: List[Entity] = [] start = 0 for token_id in range(len(token_dicts) - 1): @@ -210,8 +210,8 @@ def _convert_nested_text_to_doc_json( # 2) make Token end = start + len(token_dict["text"]) - token = SpanGroup(spans=[Span(start=start, end=end, box=token_dict["bbox"])], - id=token_id) + token = Entity(spans=[Span(start=start, end=end, box=token_dict["bbox"])], + id=token_id) token_annos.append(token) # 3) increment whitespace based on Row & Word membership. and build Rows. 
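(Aside on the renamed API used throughout these hunks: an `Entity` plays the same role `SpanGroup` did, i.e. a list of `Span`s over `Document.symbols` plus optional `Metadata`. Below is a minimal, illustrative sketch of the new construct-and-annotate flow; the two-token document is made up for this note, and it assumes `Span.box` stays optional and that `Document.annotate` attaches the `doc` back-reference as the `document.py` hunks later in this diff show:

    from mmda.types.annotation import Entity
    from mmda.types.document import Document
    from mmda.types.span import Span

    doc = Document(symbols="Hello world")
    # One Entity per token; Span offsets index into doc.symbols.
    tokens = [
        Entity(spans=[Span(start=0, end=5)]),   # "Hello"
        Entity(spans=[Span(start=6, end=11)]),  # "world"
    ]
    doc.annotate(tokens=tokens)  # attaches doc to each Entity and indexes the field
    assert doc.tokens[0].text == "Hello"

Note that the new `Entity.__init__` takes only `spans`, `boxes`, and `metadata`; `id` and `doc` are now assigned through the `Annotation` property setters once a Document is attached.)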
@@ -228,8 +228,8 @@ def _convert_nested_text_to_doc_json( # handle last token symbols += token_dicts[-1]["text"] end = start + len(token_dicts[-1]["text"]) - token = SpanGroup(spans=[Span(start=start, end=end, box=token_dicts[-1]["bbox"])], - id=len(token_dicts) - 1) + token = Entity(spans=[Span(start=start, end=end, box=token_dicts[-1]["bbox"])], + id=len(token_dicts) - 1) token_annos.append(token) # 2) build rows @@ -237,10 +237,10 @@ def _convert_nested_text_to_doc_json( (token, row_id, page_id) for token, row_id, page_id in zip(token_annos, row_ids, page_ids) ] - row_annos: List[SpanGroup] = [] + row_annos: List[Entity] = [] for row_id, tups in itertools.groupby(iterable=tokens_with_group_ids, key=lambda tup: tup[1]): row_tokens = [token for token, _, _ in tups] - row = SpanGroup( + row = Entity( spans=[ Span( start=row_tokens[0][0].start, @@ -255,10 +255,10 @@ def _convert_nested_text_to_doc_json( row_annos.append(row) # 3) build pages - page_annos: List[SpanGroup] = [] + page_annos: List[Entity] = [] for page_id, tups in itertools.groupby(iterable=tokens_with_group_ids, key=lambda tup: tup[2]): page_tokens = [token for token, _, _ in tups] - page = SpanGroup( + page = Entity( spans=[ Span( start=page_tokens[0][0].start, diff --git a/src/mmda/parsers/symbol_scraper_parser.py b/src/mmda/parsers/symbol_scraper_parser.py index 1369cccf..c1b9eb1a 100644 --- a/src/mmda/parsers/symbol_scraper_parser.py +++ b/src/mmda/parsers/symbol_scraper_parser.py @@ -18,7 +18,7 @@ from mmda.types.span import Span from mmda.types.box import Box -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from mmda.types.document import Document from mmda.parsers.parser import Parser from mmda.types.names import PagesField, RowsField, SymbolsField, TokensField @@ -192,15 +192,15 @@ def _parse_page_to_row_to_tokens(self, xml_lines: List[str], page_to_metrics: Di def _convert_nested_text_to_doc_json(self, page_to_row_to_tokens: Dict) -> Dict: text = '' - page_annos: List[SpanGroup] = [] - token_annos: List[SpanGroup] = [] - row_annos: List[SpanGroup] = [] + page_annos: List[Entity] = [] + token_annos: List[Entity] = [] + row_annos: List[Entity] = [] start = 0 for page, row_to_tokens in page_to_row_to_tokens.items(): - page_rows: List[SpanGroup] = [] + page_rows: List[Entity] = [] for row, tokens in row_to_tokens.items(): # process tokens in this row - row_tokens: List[SpanGroup] = [] + row_tokens: List[Entity] = [] for k, token in enumerate(tokens): # TODO: this is a graphical token specific to SScraper. We process it here, # instead of in XML so we can still reuse XML cache. 
But this should be replaced w/ better @@ -210,7 +210,7 @@ def _convert_nested_text_to_doc_json(self, page_to_row_to_tokens: Dict) -> Dict: text += token['text'] end = start + len(token['text']) # make token - token = SpanGroup(spans=[Span(start=start, end=end, box=token['bbox'])]) + token = Entity(spans=[Span(start=start, end=end, box=token['bbox'])]) row_tokens.append(token) token_annos.append(token) if k < len(tokens) - 1: @@ -219,14 +219,14 @@ def _convert_nested_text_to_doc_json(self, page_to_row_to_tokens: Dict) -> Dict: text += '\n' # start newline at end of row start = end + 1 # make row - row = SpanGroup(spans=[ + row = Entity(spans=[ Span(start=row_tokens[0][0].start, end=row_tokens[-1][0].end, box=Box.small_boxes_to_big_box(boxes=[span.box for t in row_tokens for span in t])) ]) page_rows.append(row) row_annos.append(row) # make page - page = SpanGroup(spans=[ + page = Entity(spans=[ Span(start=page_rows[0][0].start, end=page_rows[-1][0].end, box=Box.small_boxes_to_big_box(boxes=[span.box for r in page_rows for span in r])) ]) diff --git a/src/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py b/src/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py index 12eca518..9c298139 100644 --- a/src/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py +++ b/src/mmda/predictors/heuristic_predictors/dictionary_word_predictor.py @@ -32,7 +32,7 @@ from mmda.parsers import PDFPlumberParser from mmda.predictors.base_predictors.base_predictor import BasePredictor from mmda.predictors.heuristic_predictors.whitespace_predictor import WhitespacePredictor -from mmda.types import Metadata, Document, Span, SpanGroup +from mmda.types import Metadata, Document, Span, Entity from mmda.types.names import RowsField, TokensField @@ -102,7 +102,7 @@ def __init__( raise FileNotFoundError(f'{self.dictionary_file_path}') self.whitespace_predictor = WhitespacePredictor() - def predict(self, document: Document) -> List[SpanGroup]: + def predict(self, document: Document) -> List[Entity]: """Get words from a document as a list of SpanGroup. Args: @@ -159,7 +159,7 @@ def predict(self, document: Document) -> List[SpanGroup]: ) # 5) transformation - words: List[SpanGroup] = self._convert_to_words( + words: List[Entity] = self._convert_to_words( document=document, token_id_to_word_id=token_id_to_word_id, word_id_to_text=word_id_to_text @@ -174,7 +174,7 @@ def _precompute_whitespace_tokens(self, document: Document) -> Dict: `whitespace_tokenization` is necessary because lack of whitespace is an indicator that adjacent tokens belong in a word together. 
""" - _ws_tokens: List[SpanGroup] = self.whitespace_predictor.predict(document=document) + _ws_tokens: List[Entity] = self.whitespace_predictor.predict(document=document) document.annotate(_ws_tokens=_ws_tokens) # token -> ws_tokens @@ -492,7 +492,7 @@ def _convert_to_words( document: Document, token_id_to_word_id, word_id_to_text - ) -> List[SpanGroup]: + ) -> List[Entity]: words = [] tokens_in_word = [document.tokens[0]] current_word_id = 0 @@ -502,7 +502,7 @@ def _convert_to_words( if word_id == current_word_id: tokens_in_word.append(token) else: - word = SpanGroup( + word = Entity( spans=[span for token in tokens_in_word for span in token.spans], id=current_word_id, metadata=Metadata(text=word_id_to_text[current_word_id]) @@ -511,7 +511,7 @@ def _convert_to_words( tokens_in_word = [token] current_word_id = word_id # last bit - word = SpanGroup( + word = Entity( spans=[span for token in tokens_in_word for span in token.spans], id=current_word_id, metadata=Metadata(text=word_id_to_text[current_word_id]) diff --git a/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py b/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py index e12e88cc..6b22e0d8 100644 --- a/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py +++ b/src/mmda/predictors/heuristic_predictors/figure_table_predictors.py @@ -8,7 +8,7 @@ from ai2_internal import api from mmda.predictors.base_predictors.base_heuristic_predictor import BaseHeuristicPredictor -from mmda.types import SpanGroup, BoxGroup, Metadata +from mmda.types import Entity, BoxGroup, Metadata from mmda.types.document import Document from mmda.types.span import Span from mmda.utils.tools import MergeSpans @@ -101,7 +101,7 @@ def _get_object_caption_distance(figure_box: api.Box, caption_box: api.Box) -> f return t_cap - t_fig @staticmethod - def _predict(doc: Document, caption_type: str = 'Figure') -> List[SpanGroup]: + def _predict(doc: Document, caption_type: str = 'Figure') -> List[Entity]: """ Merges boxes corresponding to tokens of table, figure captions. For each page each caption/object create cost matrix which is distance based on get_object_caption_distance. 
        Using linear_sum_assignment find corresponding
@@ -143,7 +143,7 @@ def _predict(doc: Document, caption_type: str = 'Figure') -> List[Entity]:
         row_ind, col_ind = linear_sum_assignment(cost_matrix)

         for row, col in zip(row_ind, col_ind):
-            predictions.append(SpanGroup(spans=[Span(
+            predictions.append(Entity(spans=[Span(
                 start=merged_boxes_caption_dict[page][col].start,
                 end=merged_boxes_caption_dict[page][col].end,
                 box=merged_boxes_caption_dict[page][col].box)],
@@ -153,7 +153,7 @@ def _predict(doc: Document, caption_type: str = 'Figure') -> List[Entity]:
         return predictions

     @staticmethod
-    def predict(document: Document) -> Tuple[List[SpanGroup], List[SpanGroup]]:
+    def predict(document: Document) -> Tuple[List[Entity], List[Entity]]:
         """
         Return tuple caption -> figure, caption -> table
         Args:
diff --git a/src/mmda/predictors/heuristic_predictors/section_header_predictor.py b/src/mmda/predictors/heuristic_predictors/section_header_predictor.py
index 48b252b0..507319cf 100644
--- a/src/mmda/predictors/heuristic_predictors/section_header_predictor.py
+++ b/src/mmda/predictors/heuristic_predictors/section_header_predictor.py
@@ -23,7 +23,7 @@
 from mmda.eval.metrics import levenshtein
 from mmda.predictors.base_predictors.base_predictor import BasePredictor
-from mmda.types.annotation import Span, SpanGroup
+from mmda.types.annotation import Span, Entity
 from mmda.types.box import Box
 from mmda.types.document import Document
 from mmda.types.metadata import Metadata
@@ -63,7 +63,7 @@ def _guess_box_dimensions(spans: List[Span], index: int, outline: OutlineItem) -
     capture a reasonable area.

     Args:
-        page (SpanGroup): The page object from a PDF parser
+        spans (List[Span]): The page's token spans from a PDF parser
         index (int): The page index from 0
         outline (OutlineMetadata): Rehydrated OutlineMetadata object from querier
@@ -166,14 +166,14 @@ def __init__(
         self._x_threshold = _x_threshold
         self._y_threshold = _y_threshold

-    def predict(self, document: Document) -> List[SpanGroup]:
+    def predict(self, document: Document) -> List[Entity]:
         """Get section headers in a Document as a list of SpanGroup.

         Args:
             doc (Document): The document to process

         Returns:
-            list[SpanGroup]: SpanGroups that appear to be headers based on outline
+            list[Entity]: Entities that appear to be headers based on outline
                 metadata in the PDF (i.e., ToC or sidebar headers).
""" if _doc_has_no_outlines(document): @@ -183,10 +183,10 @@ def predict(self, document: Document) -> List[SpanGroup]: outlines = _parse_outline_metadata(document) page_to_outlines = _outlines_to_page_index(outlines) - predictions: List[SpanGroup] = [] + predictions: List[Entity] = [] for i, page in enumerate(document.pages): - tokens: List[SpanGroup] = page.tokens + tokens: List[Entity] = page.tokens spans: List[Span] = [s for t in tokens for s in t.spans] for outline in page_to_outlines[i]: @@ -252,7 +252,7 @@ def predict(self, document: Document) -> List[SpanGroup]: best_candidate = _find_best_candidate(candidates, outline) metadata = Metadata(level=outline.level, title=outline.title) predictions.append( - SpanGroup( + Entity( spans=[x.span for x in best_candidate], metadata=metadata ) ) diff --git a/src/mmda/predictors/heuristic_predictors/sentence_boundary_predictor.py b/src/mmda/predictors/heuristic_predictors/sentence_boundary_predictor.py index b6a18961..7e2a05be 100644 --- a/src/mmda/predictors/heuristic_predictors/sentence_boundary_predictor.py +++ b/src/mmda/predictors/heuristic_predictors/sentence_boundary_predictor.py @@ -4,7 +4,7 @@ import pysbd import numpy as np -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from mmda.types.span import Span from mmda.types.document import Document from mmda.types.names import PagesField, TokensField, WordsField @@ -120,7 +120,7 @@ def split_token_based_on_sentences_boundary( token_id_start = token_id_end return split - def predict(self, doc: Document) -> List[SpanGroup]: + def predict(self, doc: Document) -> List[Entity]: if hasattr(doc, WordsField): words = [ @@ -149,7 +149,7 @@ def predict(self, doc: Document) -> List[SpanGroup]: ) sentence_spans.append( - SpanGroup(spans=merge_neighbor_spans(all_token_spans)) + Entity(spans=merge_neighbor_spans(all_token_spans)) ) return sentence_spans diff --git a/src/mmda/predictors/heuristic_predictors/whitespace_predictor.py b/src/mmda/predictors/heuristic_predictors/whitespace_predictor.py index afd02df0..f9557976 100644 --- a/src/mmda/predictors/heuristic_predictors/whitespace_predictor.py +++ b/src/mmda/predictors/heuristic_predictors/whitespace_predictor.py @@ -11,7 +11,7 @@ import tokenizers from mmda.predictors.base_predictors.base_predictor import BasePredictor -from mmda.types import Metadata, Document, SpanGroup, Span, BoxGroup +from mmda.types import Metadata, Document, Entity, Span, BoxGroup from mmda.types.names import TokensField @@ -24,7 +24,7 @@ class WhitespacePredictor(BasePredictor): def __init__(self) -> None: self.whitespace_tokenizer = tokenizers.pre_tokenizers.WhitespaceSplit() - def predict(self, document: Document) -> List[SpanGroup]: + def predict(self, document: Document) -> List[Entity]: self._doc_field_checker(document) # 1) whitespace tokenization on symbols. 
each token is a nested tuple ('text', (start, end)) @@ -42,8 +42,8 @@ def predict(self, document: Document) -> List[SpanGroup]: # chunks.append(chunk) chunks = [] for i, (text, (start, end)) in enumerate(ws_tokens): - chunk = SpanGroup(spans=[Span(start=start, end=end)], - metadata=Metadata(text=text), - id=i) + chunk = Entity(spans=[Span(start=start, end=end)], + metadata=Metadata(text=text), + id=i) chunks.append(chunk) return chunks diff --git a/src/mmda/predictors/hf_predictors/bibentry_predictor/types.py b/src/mmda/predictors/hf_predictors/bibentry_predictor/types.py index ba77607e..916fc998 100644 --- a/src/mmda/predictors/hf_predictors/bibentry_predictor/types.py +++ b/src/mmda/predictors/hf_predictors/bibentry_predictor/types.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity class BibEntryLabel(Enum): @@ -50,10 +50,10 @@ class BibEntryPredictionWithSpan(BaseModel): @dataclass class BibEntryStructureSpanGroups: - bib_entry_number: List[SpanGroup] = field(default_factory=list) - bib_entry_authors: List[SpanGroup] = field(default_factory=list) - bib_entry_title: List[SpanGroup] = field(default_factory=list) - bib_entry_venue_or_event: List[SpanGroup] = field(default_factory=list) - bib_entry_year: List[SpanGroup] = field(default_factory=list) - bib_entry_doi: List[SpanGroup] = field(default_factory=list) - bib_entry_url: List[SpanGroup] = field(default_factory=list) + bib_entry_number: List[Entity] = field(default_factory=list) + bib_entry_authors: List[Entity] = field(default_factory=list) + bib_entry_title: List[Entity] = field(default_factory=list) + bib_entry_venue_or_event: List[Entity] = field(default_factory=list) + bib_entry_year: List[Entity] = field(default_factory=list) + bib_entry_doi: List[Entity] = field(default_factory=list) + bib_entry_url: List[Entity] = field(default_factory=list) diff --git a/src/mmda/predictors/hf_predictors/bibentry_predictor/utils.py b/src/mmda/predictors/hf_predictors/bibentry_predictor/utils.py index e847bfeb..21a0b403 100644 --- a/src/mmda/predictors/hf_predictors/bibentry_predictor/utils.py +++ b/src/mmda/predictors/hf_predictors/bibentry_predictor/utils.py @@ -1,6 +1,6 @@ from typing import List -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from mmda.types.document import Document from mmda.types.span import Span from mmda.predictors.hf_predictors.bibentry_predictor.types import BibEntryPredictionWithSpan, BibEntryStructureSpanGroups @@ -14,7 +14,7 @@ def mk_bib_entry_strings(document: Document) -> List[str]: def map_raw_predictions_to_mmda( - bib_entries: List[SpanGroup], + bib_entries: List[Entity], raw_preds: List[BibEntryPredictionWithSpan] ) -> BibEntryStructureSpanGroups: """ @@ -71,7 +71,7 @@ def map_raw_span(target, raw_span): if cur_pos >= end: break - target.append(SpanGroup(spans=new_spans)) + target.append(Entity(spans=new_spans)) map_raw_span(prediction.bib_entry_number, raw_pred.citation_number) for author in (raw_pred.authors or []): diff --git a/src/mmda/predictors/hf_predictors/mention_predictor.py b/src/mmda/predictors/hf_predictors/mention_predictor.py index 169a2ec5..fcc4679c 100644 --- a/src/mmda/predictors/hf_predictors/mention_predictor.py +++ b/src/mmda/predictors/hf_predictors/mention_predictor.py @@ -7,7 +7,7 @@ import torch from transformers import AutoModelForTokenClassification, AutoTokenizer -from mmda.types.annotation import Annotation, SpanGroup +from mmda.types.annotation import 
Annotation, Entity from mmda.types.document import Document from mmda.types.span import Span from mmda.parsers.pdfplumber_parser import PDFPlumberParser @@ -65,7 +65,7 @@ def __init__(self, artifacts_dir: str): # https://stackoverflow.com/a/60018731 self.model.eval() # for some reason the onnx version doesnt have an eval() - def predict(self, doc: Document, print_warnings: bool = False) -> List[SpanGroup]: + def predict(self, doc: Document, print_warnings: bool = False) -> List[Entity]: if not hasattr(doc, 'pages'): return [] @@ -75,7 +75,7 @@ def predict(self, doc: Document, print_warnings: bool = False) -> List[SpanGroup spangroups.extend(self.predict_page(page, counter=counter, print_warnings=print_warnings)) return spangroups - def predict_page(self, page: Annotation, counter: Iterator[int], print_warnings: bool = False) -> List[SpanGroup]: + def predict_page(self, page: Annotation, counter: Iterator[int], print_warnings: bool = False) -> List[Entity]: if not hasattr(page, 'tokens'): return [] @@ -126,7 +126,7 @@ def has_label_id(lbls: List[int], want_label_id: int) -> bool: def append_acc(): nonlocal acc if acc: - ret.append(SpanGroup(spans=acc, id=next(counter))) + ret.append(Entity(spans=acc, id=next(counter))) acc = [] for word_id, label_ids in zip(word_ids, word_label_ids): diff --git a/src/mmda/predictors/hf_predictors/span_group_classification_predictor.py b/src/mmda/predictors/hf_predictors/span_group_classification_predictor.py index dcc9957d..aed2a261 100644 --- a/src/mmda/predictors/hf_predictors/span_group_classification_predictor.py +++ b/src/mmda/predictors/hf_predictors/span_group_classification_predictor.py @@ -20,7 +20,7 @@ ) from mmda.types.metadata import Metadata -from mmda.types.annotation import Annotation, Span, SpanGroup +from mmda.types.annotation import Annotation, Span, Entity from mmda.types.document import Document from mmda.predictors.hf_predictors.base_hf_predictor import BaseHFPredictor @@ -241,7 +241,7 @@ def postprocess( else: new_metadata.label = None new_metadata.score = None - new_span_group = SpanGroup( + new_span_group = Entity( spans=span_group.spans, box_group=span_group.box_group, metadata=new_metadata diff --git a/src/mmda/predictors/hf_predictors/token_classification_predictor.py b/src/mmda/predictors/hf_predictors/token_classification_predictor.py index 78e553f7..6496c072 100644 --- a/src/mmda/predictors/hf_predictors/token_classification_predictor.py +++ b/src/mmda/predictors/hf_predictors/token_classification_predictor.py @@ -11,7 +11,7 @@ from mmda.types.metadata import Metadata from mmda.types.names import BlocksField, PagesField, RowsField, TokensField -from mmda.types.annotation import Annotation, Span, SpanGroup +from mmda.types.annotation import Annotation, Span, Entity from mmda.types.document import Document from mmda.predictors.hf_predictors.utils import ( convert_document_page_to_pdf_dict, @@ -97,7 +97,7 @@ def preprocess(self, page: Document, page_width: float, page_height: float) -> D page, page_width=page_width, page_height=page_height ) - def postprocess(self, document: Document, model_predictions) -> List[SpanGroup]: + def postprocess(self, document: Document, model_predictions) -> List[Entity]: token_prediction_spans = convert_sequence_tagging_to_spans(model_predictions) @@ -107,7 +107,7 @@ def postprocess(self, document: Document, model_predictions) -> List[SpanGroup]: start = min([ele.start for ele in cur_spans]) end = max([ele.end for ele in cur_spans]) - sg = SpanGroup(spans=[Span(start, end)], 
metadata=Metadata(type=label)) + sg = Entity(spans=[Span(start, end)], metadata=Metadata(type=label)) prediction_spans.append(sg) return prediction_spans diff --git a/src/mmda/predictors/hf_predictors/utils.py b/src/mmda/predictors/hf_predictors/utils.py index d631fe0a..34df34be 100644 --- a/src/mmda/predictors/hf_predictors/utils.py +++ b/src/mmda/predictors/hf_predictors/utils.py @@ -1,4 +1,4 @@ -from mmda.types.annotation import SpanGroup +from mmda.types.annotation import Entity from typing import List, Tuple, Dict import itertools @@ -39,7 +39,7 @@ def shift_index_sequence_to_zero_start(sequence): return [i - sequence_start for i in sequence] -def get_visual_group_id(token: SpanGroup, field_name: str, defaults=-1) -> int: +def get_visual_group_id(token: Entity, field_name: str, defaults=-1) -> int: if not hasattr(token, field_name): return defaults field_value = getattr(token, field_name) diff --git a/src/mmda/predictors/hf_predictors/vila_predictor.py b/src/mmda/predictors/hf_predictors/vila_predictor.py index 09b09f1c..08ae32d0 100644 --- a/src/mmda/predictors/hf_predictors/vila_predictor.py +++ b/src/mmda/predictors/hf_predictors/vila_predictor.py @@ -19,7 +19,7 @@ from mmda.types.metadata import Metadata from mmda.types.names import PagesField, RowsField, TokensField -from mmda.types.annotation import Annotation, Span, SpanGroup +from mmda.types.annotation import Annotation, Span, Entity from mmda.types.document import Document from mmda.predictors.hf_predictors.utils import ( convert_document_page_to_pdf_dict, @@ -153,7 +153,7 @@ def get_true_token_level_category_prediction( def postprocess( self, document, pdf_dict, model_inputs, model_predictions - ) -> List[SpanGroup]: + ) -> List[Entity]: true_token_prediction = self.get_true_token_level_category_prediction( pdf_dict, model_inputs, model_predictions @@ -168,7 +168,7 @@ def postprocess( start = min([ele.start for ele in cur_spans]) end = max([ele.end for ele in cur_spans]) - sg = SpanGroup(spans=[Span(start, end)], metadata=Metadata(type=label)) + sg = Entity(spans=[Span(start, end)], metadata=Metadata(type=label)) prediction_spans.append(sg) return prediction_spans diff --git a/src/mmda/types/__init__.py b/src/mmda/types/__init__.py index d0f3929c..648d3689 100644 --- a/src/mmda/types/__init__.py +++ b/src/mmda/types/__init__.py @@ -1,5 +1,5 @@ from mmda.types.document import Document -from mmda.types.annotation import SpanGroup, BoxGroup +from mmda.types.annotation import Entity, BoxGroup from mmda.types.span import Span from mmda.types.box import Box from mmda.types.image import PILImage @@ -7,7 +7,7 @@ __all__ = [ 'Document', - 'SpanGroup', + 'Entity', 'BoxGroup', 'Span', 'Box', diff --git a/src/mmda/types/annotation.py b/src/mmda/types/annotation.py index 4857df5c..33bf9e0c 100644 --- a/src/mmda/types/annotation.py +++ b/src/mmda/types/annotation.py @@ -1,15 +1,17 @@ """ -Annotations are objects that are 'aware' of the Document +Annotations are objects that are 'aware' of the Document. For example, imagine an entity +in a document; representing it as an Annotation data type would allow you to access the +Document object directly from within the Entity itself. 
-Collections of Annotations are how one constructs a new
-Iterable of Group-type objects within the Document
+@kylel

 """

 import warnings
+import logging

 from abc import abstractmethod
-from copy import deepcopy
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 from mmda.types.box import Box
 from mmda.types.metadata import Metadata
@@ -18,10 +20,7 @@
 if TYPE_CHECKING:
     from mmda.types.document import Document

-
-__all__ = ["Annotation", "BoxGroup", "SpanGroup", "Relation"]
-
-
+__all__ = ["Annotation", "BoxGroup", "Entity", "Relation"]

 def warn_deepcopy_of_annotation(obj: "Annotation") -> None:
     """Warns when a deepcopy is performed on an Annotation."""
@@ -34,19 +33,44 @@ def warn_deepcopy_of_annotation(obj: "Annotation") -> None:
     warnings.warn(msg, UserWarning, stacklevel=2)

-
 class Annotation:
-    """Annotation is intended for storing model predictions for a document."""
+    """Annotation allows us to layer different model predictions on a single document."""

-    def __init__(
-        self,
-        id: Optional[int] = None,
-        doc: Optional['Document'] = None,
-        metadata: Optional[Metadata] = None
-    ):
-        self.id = id
-        self.doc = doc
-        self.metadata = metadata if metadata else Metadata()
+    @abstractmethod
+    def __init__(self):
+        self._id: Optional[int] = None
+        self._doc: Optional['Document'] = None
+        logging.warning("Unless testing or developing, we don't recommend creating Annotations "
+                        'manually. Annotations need to store things like `id` and references '
+                        'to a `Document` to be valuable. These are all handled automatically in '
+                        '`Parsers` and `Predictors`.')
+
+    @property
+    def doc(self) -> Optional['Document']:
+        return self._doc
+
+    @doc.setter
+    def doc(self, doc: 'Document') -> None:
+        """This method attaches a Document to this Annotation, allowing the Annotation
+        to access things beyond itself within the Document (e.g. neighbors)"""
+        if self.doc is not None:
+            raise AttributeError("This annotation already has an attached document")
+        self._doc = doc
+
+    @property
+    def id(self) -> Optional[int]:
+        return self._id
+
+    @id.setter
+    def id(self, id: int) -> None:
+        """This method assigns an ID to an Annotation. Requires a Document to be attached
+        to this Annotation. ID basically gives the Annotation itself awareness of its
+        position within the broader Document."""
+        if self.id is not None:
+            raise AttributeError("This annotation already has an ID")
+        if self.doc is None:
+            raise AttributeError('This annotation is missing a Document')
+        self._id = id

     @abstractmethod
     def to_json(self) -> Dict:
@@ -57,14 +81,8 @@ def from_json(cls, annotation_dict: Dict) -> "Annotation":
         pass

-    def attach_doc(self, doc: "Document") -> None:
-        if not self.doc:
-            self.doc = doc
-        else:
-            raise AttributeError("This annotation already has an attached document")
-
-    # TODO[kylel] - comment explaining
     def __getattr__(self, field: str) -> List["Annotation"]:
+        """This method resolves a field name on the attached Document (e.g. `token.rows`)
+        into the overlapping Annotations in that field, falling back to regular
+        attribute lookup otherwise."""
         if self.doc is None:
             raise ValueError("This annotation is not attached to a document")
@@ -77,213 +95,89 @@ def __getattr__(self, field: str) -> List["Annotation"]:
         return self.__getattribute__(field)

-
-class BoxGroup(Annotation):
-    def __init__(
-        self,
-        boxes: List[Box],
-        id: Optional[int] = None,
-        doc: Optional['Document'] = None,
-        metadata: Optional[Metadata] = None,
-    ):
-        self.boxes = boxes
-        super().__init__(id=id, doc=doc, metadata=metadata)
-
-    def to_json(self) -> Dict:
-        box_group_dict = dict(
-            boxes=[box.to_json() for box in self.boxes],
-            id=self.id,
-            metadata=self.metadata.to_json()
-        )
-        return {
-            key: value for key, value in box_group_dict.items() if value
-        }  # only serialize non-null values
-
-    @classmethod
-    def from_json(cls, box_group_dict: Dict) -> "BoxGroup":
-
-        if "metadata" in box_group_dict:
-            metadata_dict = box_group_dict["metadata"]
-        else:
-            # this fallback is necessary to ensure compatibility with box
-            # groups that were create before the metadata migration and
-            # therefore have "type" in the root of the json dict instead.
-            metadata_dict = {
-                "type": box_group_dict.get("type", None)
-            }
-
-        return cls(
-            boxes=[
-                Box.from_json(box_dict=box_dict)
-                # box_group_dict["boxes"] might not be present since we
-                # minimally serialize when running to_json()
-                for box_dict in box_group_dict.get("boxes", [])
-            ],
-            id=box_group_dict.get("id", None),
-            metadata=Metadata.from_json(metadata_dict),
-        )
-
-    def __getitem__(self, key: int):
-        return self.boxes[key]
-
-    def __deepcopy__(self, memo):
-        warn_deepcopy_of_annotation(self)
-
-        box_group = BoxGroup(
-            boxes=deepcopy(self.boxes, memo),
-            id=self.id,
-            metadata=deepcopy(self.metadata, memo)
-        )
-
-        # Don't copy an attached document
-        box_group.doc = self.doc
-
-        return box_group
-
-    @property
-    def type(self) -> str:
-        return self.metadata.get("type", None)
-
-    @type.setter
-    def type(self, type: Union[str, None]) -> None:
-        self.metadata.type = type
-
-
-class SpanGroup(Annotation):
-
+class Entity(Annotation):
     def __init__(
         self,
         spans: List[Span],
-        box_group: Optional[BoxGroup] = None,
-        id: Optional[int] = None,
-        doc: Optional['Document'] = None,
-        metadata: Optional[Metadata] = None,
+        boxes: Optional[List[Box]] = None,
+        metadata: Optional[Metadata] = None
     ):
         self.spans = spans
-        self.box_group = box_group
-        super().__init__(id=id, doc=doc, metadata=metadata)
+        self.boxes = boxes
+        self.metadata = metadata if metadata else Metadata()
+        super().__init__()

     @property
     def symbols(self) -> List[str]:
-        if self.doc is not None:
-            return [
-                self.doc.symbols[span.start: span.end] for span in self.spans
-            ]
-        else:
-            return []
-
-    def annotate(
-        self, is_overwrite: bool = False, **kwargs: Iterable["Annotation"]
-    ) -> None:
         if self.doc is None:
-            raise ValueError("SpanGroup has no attached document!")
+            raise ValueError('No document attached.')
+        return [self.doc.symbols[span.start: span.end] for span in self.spans]

-        key_remaps = {k: v for k, v in kwargs.items()}
+    @property
+    def text(self) -> str:
+        maybe_text = self.metadata.get("text", None)
+        if maybe_text is None:
+            return " ".join(self.symbols)
+        return maybe_text

-        self.doc.annotate(is_overwrite=is_overwrite, **key_remaps)
+    @text.setter
+    def text(self, text: Union[str, None]) -> None:
+        self.metadata.text = text

     def to_json(self) -> Dict:
-        span_group_dict = dict(
+        entity_dict = dict(
             spans=[span.to_json() for span in self.spans],
-            id=self.id,
-            metadata=self.metadata.to_json(),
-            box_group=self.box_group.to_json() if self.box_group else None
+            boxes=[box.to_json() for box in self.boxes] if self.boxes else None,
+            metadata=self.metadata.to_json()
         )
-        return {
-            key: value
-            for key, value in span_group_dict.items()
-            if value is not None
-        }  # only serialize non-null values
+        # only serialize non-null values
+        return {k: v for k, v in entity_dict.items() if v is not None}

     @classmethod
-    def from_json(cls, span_group_dict: Dict) -> "SpanGroup":
-        box_group_dict = span_group_dict.get("box_group")
-        if box_group_dict:
-            box_group = BoxGroup.from_json(box_group_dict=box_group_dict)
-        else:
-            box_group = None
-
-        if "metadata" in span_group_dict:
-            metadata_dict = span_group_dict["metadata"]
-        else:
-            # this fallback is necessary to ensure compatibility with span
-            # groups that were create before the metadata migration and
-            # therefore have "id", "type" in the root of the json dict instead.
-            metadata_dict = {
-                "type": span_group_dict.get("type", None),
-                "text": span_group_dict.get("text", None)
-            }
-
+    def from_json(cls, entity_dict: Dict) -> "Entity":
         return cls(
-            spans=[
-                Span.from_json(span_dict=span_dict)
-                for span_dict in span_group_dict["spans"]
-            ],
-            id=span_group_dict.get("id", None),
-            metadata=Metadata.from_json(metadata_dict),
-            box_group=box_group,
+            spans=[Span.from_json(span_dict=span_dict) for span_dict in entity_dict["spans"]],
+            boxes=[Box.from_json(box_dict=box_dict) for box_dict in entity_dict['boxes']]
+            if entity_dict.get('boxes') else None,
+            metadata=Metadata.from_json(entity_dict['metadata'])
+            if entity_dict.get('metadata') else None
        )

-    def __getitem__(self, key: int):
-        return self.spans[key]
-
     @property
-    def start(self) -> Union[int, float]:
-        return (
-            min([span.start for span in self.spans])
-            if len(self.spans) > 0
-            else float("-inf")
-        )
+    def start(self) -> int:
+        return min([span.start for span in self.spans])

     @property
-    def end(self) -> Union[int, float]:
-        return (
-            max([span.end for span in self.spans])
-            if len(self.spans) > 0
-            else float("inf")
-        )
+    def end(self) -> int:
+        return max([span.end for span in self.spans])

-    def __lt__(self, other: "SpanGroup"):
-        if self.id and other.id:
-            return self.id < other.id
-        else:
-            return self.start < other.start

-    def __deepcopy__(self, memo):
-        warn_deepcopy_of_annotation(self)
+class Relation(Annotation):
+    def __init__(
+        self,
+        source: Entity,
+        target: Entity,
+        metadata: Optional[Metadata] = None
+    ):
+        self.source = source
+        self.target = target
+        self.metadata = metadata if metadata else Metadata()
+        super().__init__()

-        span_group = SpanGroup(
-            spans=deepcopy(self.spans, memo),
-            id=self.id,
-            metadata=deepcopy(self.metadata, memo),
-            box_group=deepcopy(self.box_group, memo)
+    def to_json(self) -> Dict:
+        relation_dict = dict(
+            source=self.source.id,
+            target=self.target.id,
+            metadata=self.metadata.to_json()
         )
+        # only serialize non-null values
+        return {k: v for k, v in relation_dict.items() if v is not None}

-        # Don't copy an attached document
-        span_group.doc = self.doc
-
-        return span_group
-
-    @property
-    def type(self) -> str:
-        return self.metadata.get("type", None)
-
-    @type.setter
-    def type(self, type: Union[str, None]) -> None:
-        self.metadata.type = type
-
-    @property
-    def text(self) -> str:
-        maybe_text = self.metadata.get("text", None)
-        if maybe_text is None:
-            return " ".join(self.symbols)
-        return maybe_text
-
-    @text.setter
-    def text(self, text: Union[str, None]) -> None:
-        self.metadata.text = text
-
-
-
-class Relation(Annotation):
-    pass
\ No newline at end of file
+    @classmethod
+    def from_json(cls, relation_dict: Dict, doc: 'Document') -> "Relation":
+        # TODO: resolve the `source`/`target` ids serialized by to_json() back into
+        # Entity objects via `doc`; for now this returns empty endpoints.
+        return cls(
+            source=None,
+            target=None,
+            metadata=Metadata.from_json(relation_dict['metadata'])
+            if relation_dict.get('metadata') else None
+        )
diff --git a/src/mmda/types/document.py b/src/mmda/types/document.py
index 5793237c..e1036a53 100644
--- a/src/mmda/types/document.py
+++ b/src/mmda/types/document.py
@@ -1,6 +1,6 @@
 """

-
+@kylel

 """
@@ -8,18 +8,18 @@
 import warnings
 from typing import Dict, Iterable, List, Optional

-from mmda.types.annotation import Annotation, BoxGroup, SpanGroup
+from mmda.types.annotation import Annotation, BoxGroup, Entity
 from mmda.types.image import PILImage
 from mmda.types.indexers import Indexer, SpanGroupIndexer
 from mmda.types.metadata import Metadata
-from mmda.types.names import ImagesField, MetadataField, SymbolsField
+from mmda.types.names import (
+    ImagesField,
+    MetadataField, SymbolsField, EntitiesField, RelationsField
+)
 from mmda.utils.tools import MergeSpans, allocate_overlapping_tokens_for_box


 class Document:
-
-    SPECIAL_FIELDS = [SymbolsField, ImagesField, MetadataField]
-    UNALLOWED_FIELD_NAMES = ["fields"]
+    SPECIAL_FIELDS = [SymbolsField, ImagesField, MetadataField, EntitiesField, RelationsField]

     def __init__(self, symbols: str, metadata: Optional[Metadata] = None):
         self.symbols = symbols
@@ -30,11 +30,12 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None):

     @property
     def fields(self) -> List[str]:
+        """Names of all Entity or Relation fields one can access directly on this Document."""
         return self.__fields

     # TODO: extend implementation to support DocBoxGroup
     def find_overlapping(self, query: Annotation, field_name: str) -> List[Annotation]:
-        if not isinstance(query, SpanGroup):
+        if not isinstance(query, Entity):
             raise NotImplementedError(
                 f"Currently only supports query of type SpanGroup"
             )
@@ -46,7 +47,7 @@ def add_metadata(self, **kwargs):
             self.metadata.set(k, value)

     def annotate(
-        self, is_overwrite: bool = False, **kwargs: Iterable[Annotation]
+            self, is_overwrite: bool = False, **kwargs: Iterable[Annotation]
     ) -> None:
         """Annotate the fields for document symbols (correlating the annotations with the
         symbols) and store them into the papers.
@@ -87,7 +88,7 @@ def annotate(
             ), f"Annotations in field_name {field_name} more than 1 type: {annotation_types}"
             annotation_type = annotation_types.pop()

-            if annotation_type == SpanGroup:
+            if annotation_type == Entity:
                 span_groups = self._annotate_span_group(
                     span_groups=annotations, field_name=field_name
                 )
@@ -110,9 +111,8 @@ def remove(self, field_name: str):
         self.__fields = [f for f in self.__fields if f != field_name]
         del self.__indexers[field_name]

-
     def annotate_images(
-        self, images: Iterable[PILImage], is_overwrite: bool = False
+            self, images: Iterable[PILImage], is_overwrite: bool = False
     ) -> None:
         if not is_overwrite and len(self.images) > 0:
             raise AssertionError(
@@ -134,12 +134,12 @@ def annotate_images(
         self.images = images

     def _annotate_span_group(
-        self, span_groups: List[SpanGroup], field_name: str
-    ) -> List[SpanGroup]:
+            self, span_groups: List[Entity], field_name: str
+    ) -> List[Entity]:
         """Annotate the Document using a bunch of span groups.
         It will associate the annotations with the document symbols.
         """
-        assert all([isinstance(group, SpanGroup) for group in span_groups])
+        assert all([isinstance(group, Entity) for group in span_groups])

         # 1) add Document to each SpanGroup
         for span_group in span_groups:
@@ -151,8 +151,8 @@ def _annotate_span_group(
         return span_groups

     def _annotate_box_group(
-        self, box_groups: List[BoxGroup], field_name: str
-    ) -> List[SpanGroup]:
+            self, box_groups: List[BoxGroup], field_name: str
+    ) -> List[Entity]:
         """Annotate the Document using a bunch of box groups.
         It will associate the annotations with the document symbols.
""" @@ -187,7 +187,7 @@ def _annotate_box_group( all_token_spans_with_box_group.extend(tokens_in_box) derived_span_groups.append( - SpanGroup( + Entity( spans=MergeSpans( list_of_spans=all_token_spans_with_box_group, index_distance=1 ).merge_neighbor_spans_by_symbol_distance(), @@ -267,7 +267,7 @@ def from_json(cls, doc_dict: Dict) -> "Document": for field_name, span_group_dicts in doc_dict.items(): if field_name not in doc.SPECIAL_FIELDS: span_groups = [ - SpanGroup.from_json(span_group_dict=span_group_dict) + Entity.from_json(entity_dict=span_group_dict) for span_group_dict in span_group_dicts ] field_name_to_span_groups[field_name] = span_groups diff --git a/src/mmda/types/indexers.py b/src/mmda/types/indexers.py index beb12b1f..050026aa 100644 --- a/src/mmda/types/indexers.py +++ b/src/mmda/types/indexers.py @@ -9,7 +9,7 @@ from abc import abstractmethod from dataclasses import dataclass, field -from mmda.types.annotation import SpanGroup, Annotation +from mmda.types.annotation import Entity, Annotation from ncls import NCLS import numpy as np import pandas as pd @@ -41,7 +41,7 @@ class SpanGroupIndexer(Indexer): Volume 23, Issue 11, 1 June 2007, Pages 1386–1393, https://doi.org/10.1093/bioinformatics/btl647 """ - def __init__(self, span_groups: List[SpanGroup]) -> None: + def __init__(self, span_groups: List[Entity]) -> None: starts = [] ends = [] ids = [] @@ -74,8 +74,8 @@ def _ensure_disjoint(self) -> None: f"Detected overlap with existing SpanGroup(s) {matches} for {span_group}" ) - def find(self, query: SpanGroup) -> List[SpanGroup]: - if not isinstance(query, SpanGroup): + def find(self, query: Entity) -> List[Entity]: + if not isinstance(query, Entity): raise ValueError(f'SpanGroupIndexer only works with `query` that is SpanGroup type') if not query.spans: diff --git a/src/mmda/types/names.py b/src/mmda/types/names.py index bc0ef1a9..8e46f005 100644 --- a/src/mmda/types/names.py +++ b/src/mmda/types/names.py @@ -10,6 +10,8 @@ SymbolsField = "symbols" MetadataField = "metadata" ImagesField = "images" +EntitiesField = 'entities' +RelationsField = 'relations' PagesField = "pages" TokensField = "tokens" diff --git a/src/mmda/types/old/annotations.old.py b/src/mmda/types/old/annotations.old.py deleted file mode 100644 index fbe069f7..00000000 --- a/src/mmda/types/old/annotations.old.py +++ /dev/null @@ -1,76 +0,0 @@ -""" - -Dataclass for representing annotations on documents - -@kylel - -""" - -from typing import List, Optional - -import json - - -from mmda.types.span import Span -from mmda.types.boundingbox import BoundingBox - - -class Annotation: - def to_json(self): - raise NotImplementedError - - def __repr__(self): - return json.dumps(self.to_json()) - - - -class SpanAnnotation(Annotation): - def __init__(self, span: Span, label: str): - self.span = span - self.label = label - - def to_json(self): - return {'span': self.span.to_json(), 'label': self.label} - - def __contains__(self, i: int) -> bool: - return i in self.span - - -class BoundingBoxAnnotation(Annotation): - def __init__(self, bbox: BoundingBox, label: str): - self.bbox = bbox - self.label = label - - def to_json(self): - return {'bbox': self.bbox.to_json(), 'label': self.label} - - - -if __name__ == '__main__': - - # In this example, we construct a sequence tagger training dataset using these classes. - - text = 'I live in New York. I read the New York Times.' 
- tokens = [(0, 1), (2, 6), (7, 9), (10, 13), (14, 18), (18, 19), - (20, 21), (22, 26), (27, 30), (31, 34), (35, 39), (40, 45), (45, 46)] - tokens = [Span(start=start, end=end) for start, end in tokens] - for token in tokens: - print(text[token.start:token.end]) - - tags = [Span(start=10, end=19, attr=['entity']), Span(start=31, end=46, attr=['entity'])] - for tag in tags: - print(f'{text[tag.start:tag.end]}\t{tag.tags}') - - def get_label(i: int, tags: List[SpanAnnotation]) -> Optional[str]: - for tag in tags: - if i in tag: - return tag.label - return None - - training_data = [] - for i, token in enumerate(tokens): - tag = get_label(i=i, tags=tags) - if tag: - training_data.append((token, 1)) - else: - training_data.append((token, 0)) diff --git a/src/mmda/types/old/boundingbox.old.py b/src/mmda/types/old/boundingbox.old.py deleted file mode 100644 index 102c28d1..00000000 --- a/src/mmda/types/old/boundingbox.old.py +++ /dev/null @@ -1,69 +0,0 @@ -""" - -Dataclass for doing stuff on bounding boxes - -@kyle - -""" - -from typing import List, Dict - -import json - -class BoundingBox: - def __init__(self, l: float, t: float, w: float, h: float, page: int): - """Assumes x=0.0 and y=0.0 is the top-left of the page, and - x=1.0 and y =1.0 as the bottom-right of the page""" - if l < 0.0 or l > 1.0: - raise ValueError(f'l={l} is not within 0.0~1.0') - if t < 0.0 or t > 1.0: - raise ValueError(f't={t} is not within 0.0~1.0') - if l + w < 0.0 or l + w > 1.0: - raise ValueError(f'l+w={l+w} is not within 0.0~1.0') - if t + h < 0.0 or t + h > 1.0: - raise ValueError(f't+h={t+h} is not within 0.0~1.0') - self.l = l - self.t = t - self.w = w - self.h = h - self.page = page - - @classmethod - def from_xyxy(cls, x0: float, y0: float, x1: float, y1: float, - page_height: float, page_width: float) -> 'BoundingBox': - """Assumes (x0,y0) is top-left of box and (x1,y1) is bottom-right of box - where x=0.0 and y=0.0 is top-left of the page""" - raise NotImplementedError - - @classmethod - def from_null(cls): - """Creates an empty bbox; mostly useful for quick tests""" - bbox = cls.__new__(cls) - bbox.l = None - bbox.t = None - bbox.w = None - bbox.h = None - bbox.page = None - return bbox - - @classmethod - def from_json(cls, bbox_json: Dict) -> 'BoundingBox': - l, t, w, h, page = bbox_json - bbox = BoundingBox(l=l, t=t, w=w, h=h, page=page) - return bbox - - def to_json(self): - return [self.l, self.t, self.w, self.h, self.page] - - def __repr__(self): - return json.dumps(self.to_json()) - - @classmethod - def union_bboxes(cls, bboxes: List['BoundingBox']) -> 'BoundingBox': - if len({bbox.page for bbox in bboxes}) != 1: - raise ValueError(f'Bboxes not all on same page: {bboxes}') - x1 = min([bbox.l for bbox in bboxes]) - y1 = min([bbox.t for bbox in bboxes]) - x2 = max([bbox.l + bbox.w for bbox in bboxes]) - y2 = max([bbox.t + bbox.h for bbox in bboxes]) - return BoundingBox(page=bboxes[0].page, l=x1, t=y1, w=x2 - x1, h=y2 - y1) diff --git a/src/mmda/types/old/document.old.py b/src/mmda/types/old/document.old.py deleted file mode 100644 index 767b5a3e..00000000 --- a/src/mmda/types/old/document.old.py +++ /dev/null @@ -1,272 +0,0 @@ -""" - -Dataclass for representing a document and all its constituents - -@kylel - -""" - -from typing import List, Optional, Dict, Tuple, Type - -from intervaltree import IntervalTree - -from mmda.types.boundingbox import BoundingBox -from mmda.types.annotations import Annotation, SpanAnnotation, BoundingBoxAnnotation -from mmda.types.image import Image -from mmda.types.span 
import Span - - -Text = 'text' -Page = 'page' -Token = 'token' -Row = 'row' -Sent = 'sent' -Block = 'block' -DocImage = 'image' # Conflicting the PIL Image naming - - -class Document: - - valid_types = [Page, Token, Row, Sent, Block] - - def __init__(self, text: str): - - self.text = text - - # TODO: if have span_type Map, do still need these? - self._pages: List[Span] = [] - self._tokens: List[Span] = [] - self._rows: List[Span] = [] - self._sents: List[Span] = [] - self._blocks: List[Span] = [] - self._images: List["PIL.Image"] = [] - - self._span_type_to_spans: Dict[Type, List[Span]] = { - Page: self._pages, - Token: self._tokens, - Row: self._rows, - Sent: self._sents, - Block: self._blocks - } - - self._page_index: IntervalTree = IntervalTree() - self._token_index: IntervalTree = IntervalTree() - self._row_index: IntervalTree = IntervalTree() - self._sent_index: IntervalTree = IntervalTree() - self._block_index: IntervalTree = IntervalTree() - - self._span_type_to_index: Dict[Type, IntervalTree] = { - Page: self._page_index, - Token: self._token_index, - Row: self._row_index, - Sent: self._sent_index, - Block: self._block_index - } - - @classmethod - def from_json(cls, doc_json: Dict) -> 'Document': - doc = Document(text=doc_json[Text]) - pages = [] - tokens = [] - rows = [] - sents = [] - blocks = [] - - for span_type in cls.valid_types: - if span_type in doc_json: - doc_spans = [DocSpan.from_span(span=Span.from_json(span_json=span_json), doc=doc, span_type=span_type) - for span_json in doc_json[span_type]] - if span_type == Page: - pages = doc_spans - elif span_type == Token: - tokens = doc_spans - elif span_type == Row: - rows = doc_spans - elif span_type == Sent: - sents = doc_spans - elif span_type == Block: - blocks = doc_spans - else: - raise Exception(f'Should never reach here') - - images = [Image.frombase64(image_str) for image_str in doc_json.get(DocImage,[])] - - doc.load(pages=pages, tokens=tokens, rows=rows, sents=sents, blocks=blocks, images=images) - return doc - - # TODO: consider simpler more efficient method (e.g. 
JSONL; text) - def to_json(self) -> Dict: - return { - Text: self.text, - Page: [page.to_json(exclude=['text', 'type']) for page in self.pages], - Token: [token.to_json(exclude=['text', 'type']) for token in self.tokens], - Row: [row.to_json(exclude=['text', 'type']) for row in self.rows], - Sent: [sent.to_json(exclude=['text', 'type']) for sent in self.sents], - Block: [block.to_json(exclude=['text', 'type']) for block in self.blocks], - DocImage: [image.tobase64() for image in self.images] - } - - # - # methods for building Document - # - def _build_span_index(self, spans: List[Span]) -> IntervalTree: - """Builds index for a collection of spans""" - index = IntervalTree() - for span in spans: - # constraint - all spans disjoint - existing = index[span.start:span.end] - if existing: - raise ValueError(f'Existing {existing} when attempting index {span}') - # add to index - index[span.start:span.end] = span - return index - - def _build_span_type_to_spans(self): - self._span_type_to_spans: Dict[Type, List[Span]] = { - Page: self._pages, - Token: self._tokens, - Row: self._rows, - Sent: self._sents, - Block: self._blocks - } - - def _build_span_type_to_index(self): - self._span_type_to_index: Dict[Type, IntervalTree] = { - Page: self._page_index, - Token: self._token_index, - Row: self._row_index, - Sent: self._sent_index, - Block: self._block_index - } - - def load(self, pages: Optional[List[Span]] = None, - tokens: Optional[List[Span]] = None, - rows: Optional[List[Span]] = None, - sents: Optional[List[Span]] = None, - blocks: Optional[List[Span]] = None, - images: Optional[List["PIL.Image"]] = None): - - if pages: - self._pages = pages - self._page_index = self._build_span_index(spans=pages) - if tokens: - self._tokens = tokens - self._token_index = self._build_span_index(spans=tokens) - if rows: - self._rows = rows - self._row_index = self._build_span_index(spans=rows) - if sents: - self._sents = sents - self._sent_index = self._build_span_index(spans=sents) - if blocks: - self._blocks = blocks - self._block_index = self._build_span_index(spans=blocks) - if images: - self._images = images - - self._build_span_type_to_spans() - self._build_span_type_to_index() - - # - # don't mess with Document internals - # - @property - def pages(self) -> List[Span]: - return self._pages - - @property - def tokens(self) -> List[Span]: - return self._tokens - - @property - def rows(self) -> List[Span]: - return self._rows - - @property - def sents(self) -> List[Span]: - return self._sents - - @property - def blocks(self) -> List[Span]: - return self._blocks - - @property - def images(self) -> List["PIL.Image"]: - return self._images - - # - # methods for using Document - # - # TODO: how should `containment` lookups be handled in this library? intersection is symmetric but containment isnt - # TODO: @lru.cache or some memoization might improve performance - def find(self, query: Span, types: str) -> List[Span]: - index = self._span_type_to_index[types] - return sorted([interval.data for interval in index[query.start:query.end]]) - - # TODO: what happens to the document data when annotate? 
-
-    # TODO: what happens to the document data when annotate?
-    def annotate(self, annotations: List[Annotation]):
-        for annotation in annotations:
-            if isinstance(annotation, SpanAnnotation):
-                pass
-            elif isinstance(annotation, BoundingBoxAnnotation):
-                pass
-            else:
-                pass
-        raise NotImplementedError
-
-
-
-
-class DocSpan(Span):
-    def __init__(self, start: int, end: int, doc: Document,
-                 id: Optional[int] = None, type: Optional[str] = None,
-                 text: Optional[str] = None, bbox: Optional[BoundingBox] = None):
-        super().__init__(start=start, end=end, id=id, type=type, text=text, bbox=bbox)
-        self.doc = doc
-
-    @property
-    def tokens(self) -> List:
-        if self.type == 'token':
-            raise ValueError(f'{self} is a Token and cant lookup other Tokens')
-        else:
-            return self.doc.find(query=self, types='token')
-
-    @property
-    def pages(self) -> List:
-        if self.type == 'page':
-            raise ValueError(f'{self} is a Page and cant lookup other Pages')
-        else:
-            return self.doc.find(query=self, types='page')
-
-    @property
-    def rows(self) -> List:
-        if self.type == 'row':
-            raise ValueError(f'{self} is a Row and cant lookup other Rows')
-        else:
-            return self.doc.find(query=self, types='row')
-
-    @property
-    def sents(self) -> List:
-        if self.type == 'sent':
-            raise ValueError(f'{self} is a Sentence and cant lookup other Sentences')
-        else:
-            return self.doc.find(query=self, types='sent')
-
-    @property
-    def blocks(self) -> List:
-        if self.type == 'block':
-            raise ValueError(f'{self} is a Block and cant lookup other Blocks')
-        else:
-            return self.doc.find(query=self, types='block')
-
-    @classmethod
-    def from_span(cls, span: Span, doc: Document, span_type: str) -> 'DocSpan':
-        doc_span = cls(start=span.start, end=span.end, doc=doc,
-                       type=span.type, id=span.id, text=span.text, bbox=span.bbox)
-        # these two fields are optional for `Span` & not often serialized in span_jsons, but are
-        # critical for DocSpan methods to work properly
-        if not doc_span.type:
-            doc_span.type = span_type
-        if not doc_span.text:
-            doc_span.text = doc.text[doc_span.start:doc_span.end]
-        return doc_span
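Before its removal, this `Document`/`DocSpan` pair supported cross-type lookups by interval intersection. A hypothetical round trip through the deleted legacy API, with the JSON shape following `from_json` above (offsets invented for illustration):

```python
doc_json = {
    "text": "Hello world",
    "page": [{"start": 0, "end": 11}],
    "token": [{"start": 0, "end": 5}, {"start": 6, "end": 11}],
}
doc = Document.from_json(doc_json)

page = doc.pages[0]    # a DocSpan of type 'page', text filled in by from_span
tokens = page.tokens   # delegates to doc.find(query=page, types='token')
assert [t.text for t in tokens] == ["Hello", "world"]
```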
- """ - - @abstractmethod - def to_json(self) -> Dict: - pass - - # TODO: unclear if should be `annotations` or `annotation` - @abstractmethod - @classmethod - def load(cls, field_name: str, annotations: List["Annotation"], document: Optional["Document"] = None): - pass - - -@dataclass -class DocumentPageSymbols(DocumentElement): - """Storing the symbols of a page.""" - - symbols: str - - # TODO: Add support for symbol bounding boxes and style - def __getitem__(self, key): - return self.symbols[key] - - def to_json(self): - return self.symbols - - -@dataclass -class DocumentSymbols(DocumentElement): - """Storing the symbols of a document.""" - - page_count: int - page_symbols: List[DocumentPageSymbols] = field(default_factory=list) - - # TODO[kylel] - this is more confusing than simply treating it as list[list], like it is == `docsyms[0][2:3]` - def __getitem__(self, indices): - page_id, symbol_slices = indices - assert page_id < len(self.page_symbols), "Page index out of range" - return self.page_symbols[page_id][symbol_slices] - - def to_json(self): - return [page_symbols.to_json() for page_symbols in self.page_symbols] - - @classmethod - def from_json(cls, symbols_dict: List[str]) -> "DocumentSymbols": - page_symbols = [DocumentPageSymbols(symbols=page_text) for page_text in symbols_dict] - return cls(page_count=len(page_symbols), page_symbols=page_symbols) diff --git a/src/mmda/types/old/image.old.py b/src/mmda/types/old/image.old.py deleted file mode 100644 index 76ba4a07..00000000 --- a/src/mmda/types/old/image.old.py +++ /dev/null @@ -1,32 +0,0 @@ -""" - -Dataclass for doing stuff on images of pages of a document - -@kylel, @shannons - -""" - -import base64 -from io import BytesIO - -from PIL import Image - -# Monkey patch the PIL.Image methods to add base64 conversion - -def tobase64(self): - # Ref: https://stackoverflow.com/a/31826470 - buffered = BytesIO() - self.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()) - - return img_str.decode("utf-8") - -def frombase64(img_str): - # Use the same naming style as the original Image methods - - buffered = BytesIO(base64.b64decode(img_str)) - img = Image.open(buffered) - return img - -Image.Image.tobase64 = tobase64 # This is the method applied to individual Image classes -Image.frombase64 = frombase64 # This is bind to the module, used for loading the images \ No newline at end of file diff --git a/src/mmda/types/old/span.old.py b/src/mmda/types/old/span.old.py deleted file mode 100644 index fe070cc4..00000000 --- a/src/mmda/types/old/span.old.py +++ /dev/null @@ -1,61 +0,0 @@ -""" - -Dataclass for doing stuff on token streams of a document - -@kylel - -""" - -from typing import List, Optional, Dict, Union - -from mmda.types.boundingbox import BoundingBox - -import json - - -class Span: - def __init__(self, start: int, end: int, id: Optional[int] = None, type: Optional[str] = None, - text: Optional[str] = None, bbox: Optional[BoundingBox] = None): - self.start = start - self.end = end - self.type = type - self.id = id - self.text = text - self.bbox = bbox - - @classmethod - def from_json(cls, span_json: Dict): - bbox = BoundingBox.from_json(bbox_json=span_json['bbox']) if 'bbox' in span_json else None - span = cls(start=span_json['start'], end=span_json['end'], id=span_json.get('id'), - type=span_json.get('type'), text=span_json.get('text'), bbox=bbox) - return span - - def to_json(self, exclude: List[str] = []) -> Dict: - full_json = {'start': self.start, - 'end': self.end, - 'type': self.type, - 'id': 
diff --git a/src/mmda/types/old/span.old.py b/src/mmda/types/old/span.old.py
deleted file mode 100644
index fe070cc4..00000000
--- a/src/mmda/types/old/span.old.py
+++ /dev/null
@@ -1,61 +0,0 @@
-"""
-
-Dataclass for doing stuff on token streams of a document
-
-@kylel
-
-"""
-
-from typing import List, Optional, Dict, Union
-
-from mmda.types.boundingbox import BoundingBox
-
-import json
-
-
-class Span:
-    def __init__(self, start: int, end: int, id: Optional[int] = None, type: Optional[str] = None,
-                 text: Optional[str] = None, bbox: Optional[BoundingBox] = None):
-        self.start = start
-        self.end = end
-        self.type = type
-        self.id = id
-        self.text = text
-        self.bbox = bbox
-
-    @classmethod
-    def from_json(cls, span_json: Dict):
-        bbox = BoundingBox.from_json(bbox_json=span_json['bbox']) if 'bbox' in span_json else None
-        span = cls(start=span_json['start'], end=span_json['end'], id=span_json.get('id'),
-                   type=span_json.get('type'), text=span_json.get('text'), bbox=bbox)
-        return span
-
-    def to_json(self, exclude: List[str] = []) -> Dict:
-        full_json = {'start': self.start,
-                     'end': self.end,
-                     'type': self.type,
-                     'id': self.id,
-                     'text': self.text,
-                     'bbox': self.bbox.to_json() if self.bbox else None}
-        # the `is not None` is to save serialization space for empty fields
-        return {k: v for k, v in full_json.items() if k not in exclude and v is not None}
-
-    def __repr__(self):
-        return json.dumps({k: v for k, v in self.to_json().items() if v is not None})
-
-    def __contains__(self, val: Union[int, BoundingBox]) -> bool:
-        """Checks whether an index value `i` is within the span"""
-        if isinstance(val, int):
-            return self.start <= val < self.end
-        elif isinstance(val, str):
-            return val in self.text
-        elif isinstance(val, BoundingBox):
-            raise NotImplementedError
-        else:
-            raise ValueError(f'{val} of type {type(val)} not supported for __contains__')
-
-    def __lt__(self, other: 'Span'):
-        if self.id and other.id:
-            return self.id < other.id
-        else:
-            return self.start < other.start
diff --git a/tests/test_internal_ai2/test_api.py b/tests/test_internal_ai2/test_api.py
index 404f2940..ea1d6552 100644
--- a/tests/test_internal_ai2/test_api.py
+++ b/tests/test_internal_ai2/test_api.py
@@ -20,7 +20,7 @@ class ClassificationSpanGroup(mmda_api.SpanGroup):
 class TestApi(unittest.TestCase):
     def test_vanilla_span_group(self) -> None:
-        sg_ann = mmda_ann.SpanGroup.from_json({
+        sg_ann = mmda_ann.Entity.from_json({
             'spans': [{'start': 0, 'end': 1}],
             'id': 1,
             'metadata': {'text': 'hello', 'id': 999}  # note id not used; it's just in metadata
@@ -33,7 +33,7 @@ def test_vanilla_span_group(self) -> None:
         self.assertEqual(sg_api.attributes.dict(), {})
 
     def test_classification_span_group(self) -> None:
-        sg_ann = mmda_ann.SpanGroup.from_json({
+        sg_ann = mmda_ann.Entity.from_json({
             'spans': [{'start': 0, 'end': 1}],
             'metadata': {'text': 'hello', 'id': 1}
         })
@@ -63,7 +63,7 @@ def test_classification_span_group(self) -> None:
             ClassificationSpanGroup.from_mmda(sg_ann)
 
     def test_equivalence(self):
-        sg_ann = mmda_ann.SpanGroup.from_json({
+        sg_ann = mmda_ann.Entity.from_json({
             'spans': [{'start': 0, 'end': 1}],
             'metadata': {'label': 'label', 'score': 0.5}
         })
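The pattern these tests exercise, in miniature (a sketch outside the diff; import paths assumed from the files above): a core `Entity` is lifted into the `ai2_internal` API type, which keeps the old `SpanGroup` name, and converted back.

```python
from mmda.types import annotation as mmda_ann
from ai2_internal import api as mmda_api  # path assumed from src/ai2_internal/api.py

sg_ann = mmda_ann.Entity.from_json({
    "spans": [{"start": 0, "end": 1}],
    "metadata": {"text": "hello"},
})
sg_api = mmda_api.SpanGroup.from_mmda(sg_ann)  # API layer still says SpanGroup
assert sg_api.text == "hello"

roundtrip = sg_api.to_mmda()                   # back to the core Entity type
assert roundtrip.spans[0].start == 0
```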
diff --git a/tests/test_parsers/test_override.py b/tests/test_parsers/test_override.py
index 4b836be1..6027d622 100644
--- a/tests/test_parsers/test_override.py
+++ b/tests/test_parsers/test_override.py
@@ -4,7 +4,7 @@
 from typing import List
 
 from mmda.types.document import Document
-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity
 from mmda.types.names import TokensField
 from mmda.parsers.pdfplumber_parser import PDFPlumberParser
 from mmda.predictors.base_predictors.base_predictor import BasePredictor
@@ -20,10 +20,10 @@ class MockPredictor(BasePredictor):
     REQUIRED_BACKENDS = []  # pyright: ignore
     REQUIRED_DOCUMENT_FIELDS = []  # pyright: ignore
 
-    def predict(self, document: Document) -> List[SpanGroup]:
-        token: SpanGroup
+    def predict(self, document: Document) -> List[Entity]:
+        token: Entity
         return [
-            SpanGroup(
+            Entity(
                 spans=token.spans,
                 box_group=token.box_group,
                 metadata=token.metadata,
diff --git a/tests/test_parsers/test_pdf_plumber_parser.py b/tests/test_parsers/test_pdf_plumber_parser.py
index a6f211df..e8c7dafa 100644
--- a/tests/test_parsers/test_pdf_plumber_parser.py
+++ b/tests/test_parsers/test_pdf_plumber_parser.py
@@ -7,7 +7,7 @@
 import pathlib
 import unittest
 
-from mmda.types import Document, SpanGroup, BoxGroup, Span, Box
+from mmda.types import Document, Entity, BoxGroup, Span, Box
 from mmda.parsers import PDFPlumberParser
 
 import re
@@ -126,13 +126,13 @@ def test_convert_nested_text_to_doc_json(self):
             page_ids=page_ids
         )
         assert out['symbols'] == 'abc\nd ef\ngh i\njkl'
-        tokens = [SpanGroup.from_json(span_group_dict=t_dict) for t_dict in out['tokens']]
+        tokens = [Entity.from_json(entity_dict=t_dict) for t_dict in out['tokens']]
         assert [(t.start, t.end) for t in tokens] == [(0, 2), (2, 3), (4, 5), (6, 8), (9, 11), (12, 13), (14, 15), (15, 17)]
         assert [out['symbols'][t.start : t.end] for t in tokens] == ['ab', 'c', 'd', 'ef', 'gh', 'i', 'j', 'kl']
-        rows = [SpanGroup.from_json(span_group_dict=r_dict) for r_dict in out['rows']]
+        rows = [Entity.from_json(entity_dict=r_dict) for r_dict in out['rows']]
         assert [(r.start, r.end) for r in rows] == [(0, 3), (4, 8), (9, 13), (14, 17)]
         assert [out['symbols'][r.start: r.end] for r in rows] == ['abc', 'd ef', 'gh i', 'jkl']
-        pages = [SpanGroup.from_json(span_group_dict=p_dict) for p_dict in out['pages']]
+        pages = [Entity.from_json(entity_dict=p_dict) for p_dict in out['pages']]
         assert [(p.start, p.end) for p in pages] == [(0, 8), (9, 17)]
         assert [out['symbols'][p.start: p.end] for p in pages] == ['abc\nd ef', 'gh i\njkl']
diff --git a/tests/test_predictors/test_bibentry_predictor.py b/tests/test_predictors/test_bibentry_predictor.py
index b125963c..d53d0879 100644
--- a/tests/test_predictors/test_bibentry_predictor.py
+++ b/tests/test_predictors/test_bibentry_predictor.py
@@ -5,13 +5,13 @@
     StringWithSpan
 )
 from mmda.predictors.hf_predictors.bibentry_predictor import utils
-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity
 from mmda.types.span import Span
 
 
 class TestBibEntryPredictor(unittest.TestCase):
 
     def test__map_raw_predictions_to_mmda(self):
-        sg = SpanGroup(spans=[Span(start=17778, end=17832, box=None), Span(start=20057, end=20233, box=None)])
+        sg = Entity(spans=[Span(start=17778, end=17832, box=None), Span(start=20057, end=20233, box=None)])
         raw_prediction = BibEntryPredictionWithSpan(
             citation_number=StringWithSpan(content='10', start=0, end=2),
             authors=[
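As the bib-entry test shows, an `Entity` can carry several disjoint `Span`s, for example a reference split across a column or page break. A minimal sketch with made-up offsets:

```python
from mmda.types import Entity, Span

# one logical bib entry, physically split into two stretches of the symbol stream
bib = Entity(spans=[Span(start=10, end=54), Span(start=120, end=180)])
assert len(bib.spans) == 2
n_chars = sum(s.end - s.start for s in bib.spans)  # characters covered: 44 + 60 = 104
```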
diff --git a/tests/test_predictors/test_dictionary_word_predictor.py b/tests/test_predictors/test_dictionary_word_predictor.py
index 5bbce798..2ddb63b0 100644
--- a/tests/test_predictors/test_dictionary_word_predictor.py
+++ b/tests/test_predictors/test_dictionary_word_predictor.py
@@ -10,14 +10,14 @@
 from mmda.predictors.heuristic_predictors.dictionary_word_predictor import Dictionary
 from mmda.predictors import DictionaryWordPredictor
-from mmda.types import Document, SpanGroup, Span
+from mmda.types import Document, Entity, Span
 
 
-def mock_document(symbols: str, spans: List[Span], rows: List[SpanGroup]) -> Document:
+def mock_document(symbols: str, spans: List[Span], rows: List[Entity]) -> Document:
     doc = Document(symbols=symbols)
     doc.annotate(rows=rows)
     doc.annotate(
-        tokens=[SpanGroup(id=i, spans=[span]) for i, span in enumerate(spans)]
+        tokens=[Entity(id=i, spans=[span]) for i, span in enumerate(spans)]
     )
     return doc
@@ -83,7 +83,7 @@ def test_hyphenated_word_combines(self):
             Span(start=79, end=80),  # .
         ]
-        rows = [SpanGroup(id=0, spans=spans[0:15]), SpanGroup(id=1, spans=spans[15:])]
+        rows = [Entity(id=0, spans=spans[0:15]), Entity(id=1, spans=spans[15:])]
         document = mock_document(symbols=text, spans=spans, rows=rows)
 
         with tempfile.NamedTemporaryFile() as f:
@@ -114,8 +114,8 @@ def test_next_row_single_token(self):
         ]
         rows = [
-            SpanGroup(id=1, spans=spans[0:2]),
-            SpanGroup(id=2, spans=spans[2:3]),
+            Entity(id=1, spans=spans[0:2]),
+            Entity(id=2, spans=spans[2:3]),
         ]
         document = mock_document(symbols=text, spans=spans, rows=rows)
@@ -140,9 +140,9 @@ def test_single_token_rows(self):
             Span(start=4, end=9)
         ]
         rows = [
-            SpanGroup(id=1, spans=[spans[0]]),
-            SpanGroup(id=2, spans=[spans[1]]),
-            SpanGroup(id=3, spans=[spans[2]]),
+            Entity(id=1, spans=[spans[0]]),
+            Entity(id=2, spans=[spans[1]]),
+            Entity(id=3, spans=[spans[2]]),
         ]
         document = mock_document(symbols=text, spans=spans, rows=rows)
         words = predictor.predict(document)
@@ -157,9 +157,9 @@ def test_single_token_rows(self):
             Span(start=4, end=9)
         ]
         rows = [
-            SpanGroup(id=1, spans=[spans[0]]),
-            SpanGroup(id=2, spans=[spans[1]]),
-            SpanGroup(id=3, spans=[spans[2]]),
+            Entity(id=1, spans=[spans[0]]),
+            Entity(id=2, spans=[spans[1]]),
+            Entity(id=3, spans=[spans[2]]),
         ]
         document = mock_document(symbols=text, spans=spans, rows=rows)
         words = predictor.predict(document)
@@ -178,9 +178,9 @@ def test_words_with_surrounding_punct(self):
             Span(start=4, end=5)
         ]
         rows = [
-            SpanGroup(id=1, spans=[spans[0]]),
-            SpanGroup(id=2, spans=[spans[1]]),
-            SpanGroup(id=3, spans=[spans[2]]),
+            Entity(id=1, spans=[spans[0]]),
+            Entity(id=2, spans=[spans[1]]),
+            Entity(id=3, spans=[spans[2]]),
         ]
         document = mock_document(symbols=text, spans=spans, rows=rows)
         words = predictor.predict(document)
@@ -198,9 +198,9 @@ def test_words_with_multiple_preceding_punct(self):
             Span(start=2, end=3)
         ]
         rows = [
-            SpanGroup(id=1, spans=[spans[0]]),
-            SpanGroup(id=2, spans=[spans[1]]),
-            SpanGroup(id=3, spans=[spans[2]]),
+            Entity(id=1, spans=[spans[0]]),
+            Entity(id=2, spans=[spans[1]]),
+            Entity(id=3, spans=[spans[2]]),
         ]
         document = mock_document(symbols=text, spans=spans, rows=rows)
         words = predictor.predict(document)
diff --git a/tests/test_predictors/test_span_group_classification_predictor.py b/tests/test_predictors/test_span_group_classification_predictor.py
index 84efebff..8778339e 100644
--- a/tests/test_predictors/test_span_group_classification_predictor.py
+++ b/tests/test_predictors/test_span_group_classification_predictor.py
@@ -8,7 +8,7 @@
 
 import json
 
-from mmda.types.annotation import Span, SpanGroup
+from mmda.types.annotation import Span, Entity
 from mmda.types.document import Document
 from mmda.parsers.pdfplumber_parser import PDFPlumberParser
 from mmda.predictors.hf_predictors.span_group_classification_predictor import (
@@ -24,8 +24,8 @@ class TestSpangroupClassificationPredictor(unittest.TestCase):
 
     def setUp(self):
         self.doc = Document.from_json(doc_dict=TEST_DOC_JSON)
-        sg1 = SpanGroup(spans=[Span(start=86, end=456)])
-        sg2 = SpanGroup(spans=[Span(start=457, end=641)])
+        sg1 = Entity(spans=[Span(start=86, end=456)])
+        sg2 = Entity(spans=[Span(start=457, end=641)])
         self.doc.annotate(bibs=[sg1, sg2])
 
         self.predictor = SpanGroupClassificationPredictor.from_pretrained(
diff --git a/tests/test_predictors/test_whitespace_predictor.py b/tests/test_predictors/test_whitespace_predictor.py
index bfb99352..17f35621 100644
--- a/tests/test_predictors/test_whitespace_predictor.py
+++ b/tests/test_predictors/test_whitespace_predictor.py
@@ -8,7 +8,7 @@
 import unittest
 
 from mmda.predictors import WhitespacePredictor
-from mmda.types import Document, SpanGroup, Span
+from mmda.types import Document, Entity, Span
@@ -42,7 +42,7 @@ def test_predict(self):
         ]
 
         doc = Document(symbols=symbols)
-        doc.annotate(tokens=[SpanGroup(id=i, spans=[span]) for i, span in enumerate(spans)])
+        doc.annotate(tokens=[Entity(id=i, spans=[span]) for i, span in enumerate(spans)])
 
         predictor = WhitespacePredictor()
         ws_chunks = predictor.predict(doc)
diff --git a/tests/test_types/test_indexers.py b/tests/test_types/test_indexers.py
index 05ab6885..869b5ebf 100644
--- a/tests/test_types/test_indexers.py
+++ b/tests/test_types/test_indexers.py
@@ -1,13 +1,13 @@
 import unittest
 
-from mmda.types import SpanGroup, Span
+from mmda.types import Entity, Span
 from mmda.types.indexers import SpanGroupIndexer
 
 
 class TestSpanGroupIndexer(unittest.TestCase):
     def test_overlap_within_single_spangroup_fails_checks(self):
         span_groups = [
-            SpanGroup(
+            Entity(
                 id=1,
                 spans=[
                     Span(0, 5),
@@ -21,14 +21,14 @@ def test_overlap_within_single_spangroup_fails_checks(self):
 
     def test_overlap_between_spangroups_fails_checks(self):
         span_groups = [
-            SpanGroup(
+            Entity(
                 id=1,
                 spans=[
                     Span(0, 5),
                     Span(5, 8)
                 ]
             ),
-            SpanGroup(
+            Entity(
                 id=2,
                 spans=[Span(6, 10)]
             )
@@ -39,18 +39,18 @@ def test_overlap_between_spangroups_fails_checks(self):
 
     def test_finds_matching_groups_in_doc_order(self):
         span_groups_to_index = [
-            SpanGroup(
+            Entity(
                 id=1,
                 spans=[
                     Span(0, 5),
                     Span(5, 8)
                 ]
             ),
-            SpanGroup(
+            Entity(
                 id=2,
                 spans=[Span(9, 10)]
             ),
-            SpanGroup(
+            Entity(
                 id=3,
                 spans=[Span(100, 105)]
             )
@@ -59,7 +59,7 @@ def test_finds_matching_groups_in_doc_order(self):
         index = SpanGroupIndexer(span_groups_to_index)
 
         # should intersect 1 and 2 but not 3
-        probe = SpanGroup(id=3, spans=[Span(1, 7), Span(9, 20)])
+        probe = Entity(id=3, spans=[Span(1, 7), Span(9, 20)])
         matches = index.find(probe)
 
         self.assertEqual(len(matches), 2)
diff --git a/tests/test_types/test_json_conversion.py b/tests/test_types/test_json_conversion.py
index e7a5f27d..0f9f8a62 100644
--- a/tests/test_types/test_json_conversion.py
+++ b/tests/test_types/test_json_conversion.py
@@ -8,7 +8,7 @@
 import json
 from pathlib import Path
 
-from mmda.types import BoxGroup, SpanGroup, Document, Metadata
+from mmda.types import BoxGroup, Entity, Document, Metadata
 from mmda.parsers import PDFPlumberParser
 
 
@@ -16,8 +16,8 @@
 
 def test_span_group_conversion():
-    sg = SpanGroup(spans=[], id=3, metadata=Metadata.from_json({"text": "test"}))
-    sg2 = SpanGroup.from_json(sg.to_json())
+    sg = Entity(spans=[], id=3, metadata=Metadata.from_json({"text": "test"}))
+    sg2 = Entity.from_json(sg.to_json())
     assert sg2.to_json() == sg.to_json()
     assert sg2.__dict__ == sg.__dict__
@@ -51,8 +51,8 @@ def test_doc_conversion():
     )
 
     # type annotations to keep mypy quiet
-    orig_sg: SpanGroup
-    new_sg: SpanGroup
+    orig_sg: Entity
+    new_sg: Entity
 
     for orig_sg, new_sg in field_it:
         # for each pair, they should have same metadata (type, id,
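The conversion test above reduces to a simple invariant: serializing an `Entity` and parsing it back yields an equivalent object. A standalone sketch of the same round trip (offsets invented):

```python
from mmda.types import Entity, Metadata, Span

e = Entity(spans=[Span(0, 4)], id=7, metadata=Metadata.from_json({"text": "test"}))
e2 = Entity.from_json(e.to_json())
assert e2.to_json() == e.to_json()  # JSON form is stable across a round trip
```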
diff --git a/tests/test_types/test_span_group.py b/tests/test_types/test_span_group.py
index 6068b5ee..60dd07c8 100644
--- a/tests/test_types/test_span_group.py
+++ b/tests/test_types/test_span_group.py
@@ -7,7 +7,7 @@
 import json
 import unittest
 
-from mmda.types import SpanGroup, Document, Span
+from mmda.types import Entity, Document, Span
 
 
 class TestSpanGroup(unittest.TestCase):
@@ -17,7 +17,7 @@ def setUp(self) -> None:
         self.doc = Document("This is a test document!")
 
     def test_annotation_attaches_document(self):
-        span_group = SpanGroup(id=1, spans=[Span(0, 4), Span(5, 7)])
+        span_group = Entity(id=1, spans=[Span(0, 4), Span(5, 7)])
         self.doc.annotate(tokens=[span_group])
         span_group = self.doc.tokens[0]
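End to end, the renamed type is used exactly as `SpanGroup` was: construct it over spans, annotate it onto a `Document` under a field name, and read it back through that field. A minimal sketch mirroring the test above:

```python
from mmda.types import Document, Entity, Span

doc = Document("This is a test document!")
doc.annotate(tokens=[Entity(id=1, spans=[Span(0, 4), Span(5, 7)])])

token = doc.tokens[0]  # the annotated Entity, now attached to the Document
assert [(s.start, s.end) for s in token.spans] == [(0, 4), (5, 7)]
```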