Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge boxes from citation mention model #273

Merged
merged 7 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'mmda'
version = '0.9.9'
version = '0.9.10'
description = 'MMDA - multimodal document analysis'
authors = [
{name = 'Allen Institute for Artificial Intelligence', email = '[email protected]'},
Expand Down
36 changes: 36 additions & 0 deletions src/ai2_internal/citation_mentions/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
"""

from typing import List
from itertools import groupby

from pydantic import BaseModel, BaseSettings

from ai2_internal import api
from mmda.predictors.hf_predictors.mention_predictor import MentionPredictor
from mmda.types.document import Document
from mmda.types.box import Box
from mmda.types.annotation import BoxGroup as MMDABoxGroup


class Instance(BaseModel):
Expand All @@ -39,6 +42,30 @@ class PredictorConfig(BaseSettings):
pass


def group_by_line(boxes):
    """Group boxes into lines: same page and exactly the same top coordinate.

    Matching on ``t`` is exact on purpose; boxes whose ``t`` differs at all
    end up in separate groups. The page is part of the key so that boxes on
    different pages that happen to share a ``t`` are never treated as one
    line (a merged bounding box can only belong to a single page).

    Args:
        boxes: Iterable of objects exposing ``page`` and ``t`` attributes.

    Returns:
        A list of lists of boxes, one inner list per ``(page, t)`` pair,
        ordered by page and then by ``t``. Boxes within a group keep their
        relative input order. An empty input yields an empty list.
    """
    def line_key(box):
        return (box.page, box.t)

    # groupby only merges adjacent items, so sort by the same key first.
    ordered = sorted(boxes, key=line_key)
    return [list(line_boxes) for _, line_boxes in groupby(ordered, key=line_key)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you seeing that the line grouping is working well? I think it could be better if we had rows passed into this model, so that we can take advantage of pdfplumber's grouping by lines which I think is more robust than requiring exact t coordinates matching up (see

def _simple_line_detection(
) -- or maybe just box.t +/- a little amount calculated from getting average token heights? (see
return np.average([[span.box.w, span.box.h] for token in tokens
) -- just options, but if you're seeing good results with this grouping I wouldn't worry about making this more complicated

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually seems to work well! If the lines don't have the same t, then I think it's safest to just not do anything.


def calc_bounding_box(boxes):
    """Return the smallest ``Box`` that encloses every box in ``boxes``.

    Args:
        boxes: Non-empty, indexable sequence of objects exposing ``l``, ``t``,
            ``w``, ``h`` and ``page`` attributes.

    Returns:
        A ``Box`` spanning from the minimum left/top edge to the maximum
        right/bottom edge. Its page is copied from the first box.
    """
    left = min(box.l for box in boxes)
    top = min(box.t for box in boxes)
    right = max(box.l + box.w for box in boxes)
    bottom = max(box.t + box.h for box in boxes)
    return Box(
        l=left,
        t=top,
        w=right - left,
        h=bottom - top,
        page=boxes[0].page,
    )

def merge_boxes(boxes):
    """Collapse ``boxes`` into one bounding box per line.

    Lines are determined by ``group_by_line``; each resulting group is
    replaced by the box returned from ``calc_bounding_box``.

    Args:
        boxes: Iterable of boxes to merge.

    Returns:
        A list with one merged box per group, in the order the groups are
        produced by ``group_by_line``.
    """
    merged = []
    for line in group_by_line(boxes):
        merged.append(calc_bounding_box(line))
    return merged

def all_spans_close(sg, max_gap=5):
    """Tell whether the spans of ``sg`` follow each other with only small gaps.

    Spans are ordered by ``start``; every consecutive pair must satisfy
    ``prev.end <= next.start <= prev.end + max_gap``. Overlapping spans
    (``next.start < prev.end``) therefore make the result ``False``.

    Args:
        sg: Object exposing a ``spans`` iterable whose items have ``start``
            and ``end`` attributes.
        max_gap: Largest allowed distance between one span's ``end`` and the
            next span's ``start``. Defaults to 5, the previously hard-coded
            value.

    Returns:
        ``True`` if every consecutive pair is within ``max_gap`` (also when
        there are zero or one spans), otherwise ``False``.
    """
    ordered = sorted(sg.spans, key=lambda span: span.start)
    return all(
        prev.end <= nxt.start <= prev.end + max_gap
        for prev, nxt in zip(ordered, ordered[1:])
    )

def build_box_group(sg):
    """Build an ``MMDABoxGroup`` from the ``box`` of every span in ``sg``.

    Args:
        sg: Object exposing a ``spans`` iterable whose items have a ``box``
            attribute.

    Returns:
        An ``MMDABoxGroup`` whose ``boxes`` list holds each span's box, in
        span order.
    """
    return MMDABoxGroup(boxes=[member.box for member in sg.spans])


class Predictor:
"""
Interface on to your underlying model.
Expand All @@ -64,6 +91,15 @@ def predict_one(self, inst: Instance) -> Prediction:
doc.annotate(pages=[sg.to_mmda() for sg in inst.pages])

prediction_span_groups = self._predictor.predict(doc)
box_groups = [build_box_group(sg) for sg in prediction_span_groups]
# set box_groups and delete span boxes
for sg, bg in zip(prediction_span_groups, box_groups):
sg.box_group = bg
for span in sg.spans:
span.box = None
for sg in prediction_span_groups:
if all_spans_close(sg):
sg.box_group.boxes = merge_boxes(sg.box_group.boxes)
doc.annotate(citation_mentions=prediction_span_groups)

return Prediction(mentions=[api.SpanGroup.from_mmda(sg) for sg in doc.citation_mentions])
Expand Down
Loading