[wip] Kylel/2022 12/v2 release #188

Open. Wants to merge 7 commits into main.
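Summary of the diff below, as inferred from the changes themselves: this WIP v2 pass renames mmda's core annotation type `SpanGroup` to `Entity` across the example scripts, the `ai2_internal` API layer, the citation-link featurizer, and the parsers. A minimal before/after sketch, using only constructor patterns that appear in this diff and assuming `Span` can be built without an explicit `box`:

```python
from mmda.types.annotation import Entity  # previously: from mmda.types.annotation import SpanGroup
from mmda.types.metadata import Metadata
from mmda.types.span import Span

# Construction with an explicit id, as in pdfplumber_parser.py:
token = Entity(spans=[Span(start=0, end=5)], id=0)

# Construction with metadata-carried text, as in grobid_parser.py:
title = Entity(spans=[Span(start=0, end=5)], metadata=Metadata(text="hello"))
```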
12 changes: 6 additions & 6 deletions examples/bibliography_extraction/main.py
@@ -10,14 +10,14 @@
 from mmda.predictors.hf_predictors.vila_predictor import HVILAPredictor
 from mmda.predictors.tesseract_predictors import TesseractBlockPredictor
 from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
-from mmda.types.annotation import BoxGroup, SpanGroup
+from mmda.types.annotation import BoxGroup, Entity
 from mmda.types.document import Document

 PDF_PATH = "resources/maml.pdf"


-def _clone_span_group(span_group: SpanGroup):
-    return SpanGroup(
+def _clone_span_group(span_group: Entity):
+    return Entity(
         spans=span_group.spans,
         id=span_group.id,
         text=span_group.text,
@@ -46,7 +46,7 @@ def _index_document_pages(document) -> List[PageSpan]:
     return page_spans


-def _find_page_num(span_group: SpanGroup, page_spans: List[PageSpan]) -> int:
+def _find_page_num(span_group: Entity, page_spans: List[PageSpan]) -> int:
     s = min([span.start for span in span_group.spans])
     e = max([span.end for span in span_group.spans])

Expand All @@ -58,7 +58,7 @@ def _find_page_num(span_group: SpanGroup, page_spans: List[PageSpan]) -> int:


def _highest_overlap_block(
token: SpanGroup, blocks: Iterable[BoxGroup]
token: Entity, blocks: Iterable[BoxGroup]
) -> Optional[BoxGroup]:
assert len(token.spans) == 1
token_box = token.spans[0].box
@@ -78,7 +78,7 @@ def _highest_overlap_block(
     return found_block


-def extract_bibliography_grotoap2(document: Document) -> Iterable[SpanGroup]:
+def extract_bibliography_grotoap2(document: Document) -> Iterable[Entity]:
     """GROTOAP2 has type 1 for REFERENCES"""
     return [_clone_span_group(sg) for sg in document.preds if sg.type == 1]
6 changes: 3 additions & 3 deletions examples/vila_for_scidoc_parsing/main.py
@@ -8,7 +8,7 @@

 from mmda.parsers.pdfplumber_parser import PDFPlumberParser
 from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity
 from mmda.predictors.lp_predictors import LayoutParserPredictor
 from mmda.predictors.hf_predictors.vila_predictor import IVILAPredictor, HVILAPredictor

@@ -33,7 +33,7 @@

 def draw_tokens(
     image,
-    doc_tokens: List[SpanGroup],
+    doc_tokens: List[Entity],
     color_map=None,
     token_boundary_width=0,
     alpha=0.25,
@@ -65,7 +65,7 @@ def draw_tokens(

 def draw_blocks(
     image,
-    doc_tokens: List[SpanGroup],
+    doc_tokens: List[Entity],
     color_map=None,
     token_boundary_width=0,
     alpha=0.25,
8 changes: 4 additions & 4 deletions examples/vlue_evaluation/main.py
@@ -17,7 +17,7 @@
     IVILAPredictor)
 from mmda.predictors.lp_predictors import LayoutParserPredictor
 from mmda.rasterizers.rasterizer import PDF2ImageRasterizer
-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity
 from mmda.types.document import Document

@@ -31,7 +31,7 @@ class VluePrediction:


 def _vila_docbank_extract_entities(types: List[str]):
-    def extractor(doc: Document) -> Dict[str, List[SpanGroup]]:
+    def extractor(doc: Document) -> Dict[str, List[Entity]]:
         mapping = {
             "paragraph": 0,
             "title": 1,
@@ -62,7 +62,7 @@ def extractor(doc: Document) -> Dict[str, List[SpanGroup]]:


 def _vila_grotoap2_extract_entities(types: List[str]):
-    def extractor(doc: Document) -> Dict[str, List[SpanGroup]]:
+    def extractor(doc: Document) -> Dict[str, List[Entity]]:
         # TODO: Have some sort of unified mapping between this and docbank
         # TODO: Below title and abstract have been lower-cased to match docbank
         mapping = {
@@ -107,7 +107,7 @@ def vila_prediction(
     id_: str,
     doc: Document,
     vila_predictor: BaseVILAPredictor,  # pylint: disable=redefined-outer-name
-    vila_extractor: Callable[[Document], Dict[str, List[SpanGroup]]],
+    vila_extractor: Callable[[Document], Dict[str, List[Entity]]],
 ) -> VluePrediction:
     # Predict token types
     span_groups = vila_predictor.predict(doc)
6 changes: 3 additions & 3 deletions src/ai2_internal/api.py
@@ -115,7 +115,7 @@ class SpanGroup(Annotation):
     text: Optional[str]

     @classmethod
-    def from_mmda(cls, span_group: mmda_ann.SpanGroup) -> "SpanGroup":
+    def from_mmda(cls, span_group: mmda_ann.Entity) -> "SpanGroup":
         box_group = (
             BoxGroup.from_mmda(span_group.box_group)
             if span_group.box_group is not None
@@ -136,13 +136,13 @@ def from_mmda(cls, span_group: mmda_ann.SpanGroup) -> "SpanGroup":
             ret.box_group = BoxGroup.from_mmda(span_group.box_group)
         return ret

-    def to_mmda(self) -> mmda_ann.SpanGroup:
+    def to_mmda(self) -> mmda_ann.Entity:
         metadata = mmda_ann.Metadata.from_json(self.attributes.dict())
         if self.type:
             metadata.type = self.type
         if self.text:
             metadata.text = self.text
-        return mmda_ann.SpanGroup(
+        return mmda_ann.Entity(
             metadata=metadata,
             spans=[span.to_mmda() for span in self.spans],
             box_group=self.box_group.to_mmda() if self.box_group else None,
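Note on the hunks above: only the mmda-side type changes name; the Pydantic model in `ai2_internal/api.py` is still called `SpanGroup`, so `from_mmda`/`to_mmda` now bridge the two names. A rough round-trip sketch under that reading, again assuming a box-less `Span` is valid:

```python
from ai2_internal import api
from mmda.types import annotation as mmda_ann
from mmda.types.span import Span

# mmda-side object (new name) -> API-side Pydantic model (old name) -> back again
entity = mmda_ann.Entity(spans=[Span(start=0, end=4)])
api_span_group = api.SpanGroup.from_mmda(entity)
round_tripped = api_span_group.to_mmda()  # an mmda_ann.Entity once more
```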
4 changes: 2 additions & 2 deletions src/ai2_internal/vila/interface.py
@@ -16,7 +16,7 @@
 from mmda.predictors.hf_predictors.token_classification_predictor import (
     IVILATokenClassificationPredictor,
 )
-from mmda.types.document import Document, SpanGroup
+from mmda.types.document import Document, Entity
 from mmda.types.image import frombase64

 logger = logging.getLogger(__name__)
@@ -54,7 +54,7 @@ class Prediction(BaseModel):
     groups: List[api.SpanGroup]

     @classmethod
-    def from_mmda(cls, groups: List[SpanGroup]) -> "Prediction":
+    def from_mmda(cls, groups: List[Entity]) -> "Prediction":
         return cls(groups=[api.SpanGroup.from_mmda(grp) for grp in groups])
4 changes: 2 additions & 2 deletions src/mmda/featurizers/citation_link_featurizers.py
@@ -4,7 +4,7 @@
 from thefuzz import fuzz
 from typing import List, Tuple, Dict

-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity


 DIGITS = re.compile(r'[0-9]+')
@@ -24,7 +24,7 @@
 MATCH_FIRST_TOKEN = "match_first_token"

 class CitationLink:
-    def __init__(self, mention: SpanGroup, bib: SpanGroup):
+    def __init__(self, mention: Entity, bib: Entity):
         self.mention = mention
         self.bib = bib
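For orientation, `CitationLink` changes only in its type hints: it still pairs a citation mention with a bibliography entry, both now `Entity` objects. A minimal construction sketch with hypothetical character offsets:

```python
from mmda.featurizers.citation_link_featurizers import CitationLink
from mmda.types.annotation import Entity
from mmda.types.span import Span

mention = Entity(spans=[Span(start=10, end=13)])  # e.g. an inline "[1]"
bib = Entity(spans=[Span(start=900, end=1020)])   # e.g. the matching bibliography entry
link = CitationLink(mention=mention, bib=bib)
```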
14 changes: 7 additions & 7 deletions src/mmda/parsers/grobid_parser.py
@@ -13,7 +13,7 @@
 import json

 from mmda.parsers.parser import Parser
-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity
 from mmda.types.document import Document
 from mmda.types.metadata import Metadata
 from mmda.types.span import Span
@@ -22,8 +22,8 @@
 NS = {"tei": "http://www.tei-c.org/ns/1.0"}


-def _null_span_group() -> SpanGroup:
-    sg = SpanGroup(spans=[])
+def _null_span_group() -> Entity:
+    sg = Entity(spans=[])
     return sg


@@ -97,7 +97,7 @@ def _parse_xml_to_doc(self, xml: str) -> Document:

         return document

-    def _get_title(self, root: et.Element) -> SpanGroup:
+    def _get_title(self, root: et.Element) -> Entity:
         matches = root.findall(".//tei:titleStmt/tei:title", NS)

         if len(matches) == 0:
@@ -110,10 +110,10 @@ def _get_title(self, root: et.Element) -> SpanGroup:
         tokens = text.split()
         spans = _get_token_spans(text, tokens)

-        sg = SpanGroup(spans=spans, metadata=Metadata(text=text))
+        sg = Entity(spans=spans, metadata=Metadata(text=text))
         return sg

-    def _get_abstract(self, root: et.Element, offset: int) -> SpanGroup:
+    def _get_abstract(self, root: et.Element, offset: int) -> Entity:
         matches = root.findall(".//tei:profileDesc//tei:abstract//", NS)

         if len(matches) == 0:
@@ -124,5 +124,5 @@ def _get_abstract(self, root: et.Element, offset: int) -> SpanGroup:
         tokens = text.split()
         spans = _get_token_spans(text, tokens, offset=offset)

-        sg = SpanGroup(spans=spans, metadata=Metadata(text=text))
+        sg = Entity(spans=spans, metadata=Metadata(text=text))
         return sg
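One detail worth calling out in this file: when a TEI query matches nothing, `_null_span_group` now returns an empty `Entity` rather than `None`, keeping downstream consumers uniform. The fallback, exactly as in the hunk above:

```python
from mmda.types.annotation import Entity

# Empty-but-valid annotation returned when GROBID finds no title/abstract.
null_entity = Entity(spans=[])
```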
20 changes: 10 additions & 10 deletions src/mmda/parsers/pdfplumber_parser.py
@@ -6,7 +6,7 @@
 import itertools
 from mmda.types.span import Span
 from mmda.types.box import Box
-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity
 from mmda.types.document import Document
 from mmda.parsers.parser import Parser
 from mmda.types.names import PagesField, RowsField, SymbolsField, TokensField
@@ -195,7 +195,7 @@ def _convert_nested_text_to_doc_json(

         # 1) build tokens & symbols
         symbols = ""
-        token_annos: List[SpanGroup] = []
+        token_annos: List[Entity] = []
         start = 0
         for token_id in range(len(token_dicts) - 1):

@@ -210,8 +210,8 @@

             # 2) make Token
             end = start + len(token_dict["text"])
-            token = SpanGroup(spans=[Span(start=start, end=end, box=token_dict["bbox"])],
-                              id=token_id)
+            token = Entity(spans=[Span(start=start, end=end, box=token_dict["bbox"])],
+                           id=token_id)
             token_annos.append(token)

             # 3) increment whitespace based on Row & Word membership. and build Rows.
@@ -228,19 +228,19 @@

         # handle last token
         symbols += token_dicts[-1]["text"]
         end = start + len(token_dicts[-1]["text"])
-        token = SpanGroup(spans=[Span(start=start, end=end, box=token_dicts[-1]["bbox"])],
-                          id=len(token_dicts) - 1)
+        token = Entity(spans=[Span(start=start, end=end, box=token_dicts[-1]["bbox"])],
+                       id=len(token_dicts) - 1)
         token_annos.append(token)

         # 2) build rows
         tokens_with_group_ids = [
             (token, row_id, page_id)
             for token, row_id, page_id in zip(token_annos, row_ids, page_ids)
         ]
-        row_annos: List[SpanGroup] = []
+        row_annos: List[Entity] = []
         for row_id, tups in itertools.groupby(iterable=tokens_with_group_ids, key=lambda tup: tup[1]):
             row_tokens = [token for token, _, _ in tups]
-            row = SpanGroup(
+            row = Entity(
                 spans=[
                     Span(
                         start=row_tokens[0][0].start,
@@ -255,10 +255,10 @@
             row_annos.append(row)

         # 3) build pages
-        page_annos: List[SpanGroup] = []
+        page_annos: List[Entity] = []
         for page_id, tups in itertools.groupby(iterable=tokens_with_group_ids, key=lambda tup: tup[2]):
             page_tokens = [token for token, _, _ in tups]
-            page = SpanGroup(
+            page = Entity(
                 spans=[
                     Span(
                         start=page_tokens[0][0].start,
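The row- and page-building hunks above share one pattern: tokens arrive in reading order, so consecutive `(token, row_id, page_id)` tuples can be grouped with `itertools.groupby` without sorting, and each group's extent is wrapped in a fresh `Entity`. A standalone sketch of that grouping idea, with hypothetical offsets (the real code also merges token boxes):

```python
import itertools
from typing import List

from mmda.types.annotation import Entity
from mmda.types.span import Span

# Four tokens in reading order; the first two share row 0, the last two row 1.
tokens: List[Entity] = [Entity(spans=[Span(start=i * 5, end=i * 5 + 4)], id=i) for i in range(4)]
row_ids = [0, 0, 1, 1]

rows: List[Entity] = []
for row_id, group in itertools.groupby(zip(tokens, row_ids), key=lambda pair: pair[1]):
    row_tokens = [tok for tok, _ in group]
    # One span per row, from the first token's start to the last token's end.
    rows.append(Entity(spans=[Span(start=row_tokens[0].spans[0].start,
                                   end=row_tokens[-1].spans[0].end)]))
```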
18 changes: 9 additions & 9 deletions src/mmda/parsers/symbol_scraper_parser.py
@@ -18,7 +18,7 @@

 from mmda.types.span import Span
 from mmda.types.box import Box
-from mmda.types.annotation import SpanGroup
+from mmda.types.annotation import Entity
 from mmda.types.document import Document
 from mmda.parsers.parser import Parser
 from mmda.types.names import PagesField, RowsField, SymbolsField, TokensField
@@ -192,15 +192,15 @@ def _parse_page_to_row_to_tokens(self, xml_lines: List[str], page_to_metrics: Di

     def _convert_nested_text_to_doc_json(self, page_to_row_to_tokens: Dict) -> Dict:
         text = ''
-        page_annos: List[SpanGroup] = []
-        token_annos: List[SpanGroup] = []
-        row_annos: List[SpanGroup] = []
+        page_annos: List[Entity] = []
+        token_annos: List[Entity] = []
+        row_annos: List[Entity] = []
         start = 0
         for page, row_to_tokens in page_to_row_to_tokens.items():
-            page_rows: List[SpanGroup] = []
+            page_rows: List[Entity] = []
             for row, tokens in row_to_tokens.items():
                 # process tokens in this row
-                row_tokens: List[SpanGroup] = []
+                row_tokens: List[Entity] = []
                 for k, token in enumerate(tokens):
                     # TODO: this is a graphical token specific to SScraper. We process it here,
                     # instead of in XML so we can still reuse XML cache. But this should be replaced w/ better
@@ -210,7 +210,7 @@ def _convert_nested_text_to_doc_json(self, page_to_row_to_tokens: Dict) -> Dict:
                     text += token['text']
                     end = start + len(token['text'])
                     # make token
-                    token = SpanGroup(spans=[Span(start=start, end=end, box=token['bbox'])])
+                    token = Entity(spans=[Span(start=start, end=end, box=token['bbox'])])
                     row_tokens.append(token)
                     token_annos.append(token)
                     if k < len(tokens) - 1:
@@ -219,14 +219,14 @@
                 text += '\n'  # start newline at end of row
                 start = end + 1
                 # make row
-                row = SpanGroup(spans=[
+                row = Entity(spans=[
                     Span(start=row_tokens[0][0].start, end=row_tokens[-1][0].end,
                          box=Box.small_boxes_to_big_box(boxes=[span.box for t in row_tokens for span in t]))
                 ])
                 page_rows.append(row)
                 row_annos.append(row)
             # make page
-            page = SpanGroup(spans=[
+            page = Entity(spans=[
                 Span(start=page_rows[0][0].start, end=page_rows[-1][0].end,
                      box=Box.small_boxes_to_big_box(boxes=[span.box for r in page_rows for span in r]))
             ])
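Unlike the pdfplumber version, the rows and pages here also carry a merged bounding box built with `Box.small_boxes_to_big_box`. A sketch of that merge, assuming mmda's `Box(l, t, w, h, page)` constructor and hypothetical page-relative coordinates:

```python
from mmda.types.box import Box

# Two token boxes on the same page, merged into one row-level box.
row_box = Box.small_boxes_to_big_box(boxes=[
    Box(l=0.10, t=0.10, w=0.20, h=0.02, page=0),
    Box(l=0.35, t=0.10, w=0.20, h=0.02, page=0),
])
```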