-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding relations #158
base: main
Are you sure you want to change the base?
adding relations #158
Changes from all commits
d9a64ac
b17bbb7
3132290
055dfea
813e134
affc3e8
87de5c9
92b79d4
70bc6bf
34d6fc6
32a2e11
d763d09
98da4ea
9d09af7
e827f02
8490804
2577c7d
e5d1d28
f83eb00
84682ee
b93462c
24f0c39
818349b
5820934
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -5,7 +5,10 @@ | |||||
Collections of Annotations are how one constructs a new | ||||||
Iterable of Group-type objects within the Document | ||||||
|
||||||
@kylel, @lucas | ||||||
|
||||||
""" | ||||||
import logging | ||||||
import warnings | ||||||
from abc import abstractmethod | ||||||
from copy import deepcopy | ||||||
|
@@ -18,11 +21,9 @@ | |||||
if TYPE_CHECKING: | ||||||
from mmda.types.document import Document | ||||||
|
||||||
|
||||||
__all__ = ["Annotation", "BoxGroup", "SpanGroup", "Relation"] | ||||||
|
||||||
|
||||||
|
||||||
def warn_deepcopy_of_annotation(obj: "Annotation") -> None: | ||||||
"""Warns when a deepcopy is performed on an Annotation.""" | ||||||
|
||||||
|
@@ -34,6 +35,22 @@ def warn_deepcopy_of_annotation(obj: "Annotation") -> None: | |||||
warnings.warn(msg, UserWarning, stacklevel=2) | ||||||
|
||||||
|
||||||
class AnnotationName:
    """Stores a name that uniquely identifies this Annotation within a Document.

    A name is the pair ``(field, id)``. Its string form, produced by
    ``__str__`` and parsed back by ``from_str``, is ``"<field>-<id>"``.
    """

    def __init__(self, field: str, id: int):
        self.field = field
        self.id = id

    def __str__(self) -> str:
        return f"{self.field}-{self.id}"

    @classmethod
    def from_str(cls, s: str) -> 'AnnotationName':
        """Parse a string of the form ``"<field>-<id>"`` back into a name.

        Args:
            s: the serialized name, as produced by ``str(annotation_name)``.

        Returns:
            An instance of ``cls`` (so subclasses get instances of themselves).

        Raises:
            ValueError: if `s` contains no ``-`` separator, or if the part
                after the last ``-`` is not an integer.
        """
        # Split on the LAST hyphen only: the field itself may contain
        # hyphens (e.g. "my-field-3"), while the id is always the final part.
        field, id_str = s.rsplit('-', 1)
        # Build through `cls` rather than the hard-coded class name so that
        # subclasses inheriting this constructor are instantiated correctly.
        return cls(field=field, id=int(id_str))
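There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change (the suggested diff was lost in extraction; from context, build the result with `cls(field=field, id=id)` rather than `AnnotationName(field=field, id=id)`):
Prevents issues when inheriting (if we ever decide to). |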
||||||
|
||||||
|
||||||
class Annotation: | ||||||
"""Annotation is intended for storing model predictions for a document.""" | ||||||
|
@@ -42,40 +59,55 @@ def __init__( | |||||
self, | ||||||
id: Optional[int] = None, | ||||||
doc: Optional['Document'] = None, | ||||||
field: Optional[str] = None, | ||||||
metadata: Optional[Metadata] = None | ||||||
): | ||||||
self.id = id | ||||||
self.doc = doc | ||||||
self.field = field | ||||||
self.metadata = metadata if metadata else Metadata() | ||||||
|
||||||
@abstractmethod | ||||||
def to_json(self) -> Dict: | ||||||
pass | ||||||
raise NotImplementedError | ||||||
|
||||||
@classmethod | ||||||
@abstractmethod | ||||||
def from_json(cls, annotation_dict: Dict) -> "Annotation": | ||||||
pass | ||||||
raise NotImplementedError | ||||||
|
||||||
@property
def name(self) -> Optional[AnnotationName]:
    """Return the `AnnotationName` built from this annotation's field and id.

    Returns:
        ``AnnotationName(field=self.field, id=self.id)`` when both are set,
        otherwise ``None``.
    """
    # `id` is compared against None explicitly: 0 is a legitimate id, and a
    # plain truthiness test would wrongly report such an annotation as
    # unnamed.
    if not self.field or self.id is None:
        return None
    return AnnotationName(field=self.field, id=self.id)
|
||||||
def attach_doc(self, doc: "Document") -> None: | ||||||
def _attach_doc(self, doc: "Document", field: str) -> None: | ||||||
if not self.doc: | ||||||
self.doc = doc | ||||||
self.field = field | ||||||
else: | ||||||
raise AttributeError("This annotation already has an attached document") | ||||||
|
||||||
# TODO[kylel] - comment explaining | ||||||
def __getattr__(self, field: str) -> List["Annotation"]: | ||||||
if self.doc is None: | ||||||
raise ValueError("This annotation is not attached to a document") | ||||||
def _get_siblings(self) -> List['Annotation']: | ||||||
"""This method gets all other objects sharing the same field as the current object. | ||||||
Only works after a Document has been attached, which is how objects learn their `field`.""" | ||||||
if not self.doc: | ||||||
raise AttributeError("This annotation does not have an attached document") | ||||||
return self.doc.__getattr__(self.field) | ||||||
|
||||||
if field in self.doc.fields: | ||||||
return self.doc.find_overlapping(self, field) | ||||||
def __getattr__(self, field: str) -> List["Annotation"]: | ||||||
"""This method allows jumping from an object of one field to all overlapping | ||||||
objects of another field. For example `page.tokens` jumps from a particular page | ||||||
to all its intersecting tokens.""" | ||||||
if not self.doc: | ||||||
raise AttributeError("This annotation does not have an attached document") | ||||||
|
||||||
if field in self.doc.fields: | ||||||
return self.doc.find_overlapping(self, field) | ||||||
|
||||||
return self.__getattribute__(field) | ||||||
|
||||||
else: | ||||||
return [] | ||||||
|
||||||
|
||||||
class BoxGroup(Annotation): | ||||||
|
@@ -84,12 +116,14 @@ def __init__( | |||||
boxes: List[Box], | ||||||
id: Optional[int] = None, | ||||||
doc: Optional['Document'] = None, | ||||||
field: Optional[str] = None, | ||||||
metadata: Optional[Metadata] = None, | ||||||
Comment on lines
116
to
120
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not related to this PR specifically, but I think we should document these arguments in a docstring. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yea we'll need that |
||||||
): | ||||||
self.boxes = boxes | ||||||
super().__init__(id=id, doc=doc, metadata=metadata) | ||||||
super().__init__(id=id, doc=doc, field=field, metadata=metadata) | ||||||
|
||||||
def to_json(self) -> Dict: | ||||||
"""Note: even if `doc` or `field` are attached, don't include in JSON to avoid bloat""" | ||||||
box_group_dict = dict( | ||||||
boxes=[box.to_json() for box in self.boxes], | ||||||
id=self.id, | ||||||
|
@@ -132,6 +166,7 @@ def __deepcopy__(self, memo): | |||||
box_group = BoxGroup( | ||||||
boxes=deepcopy(self.boxes, memo), | ||||||
id=self.id, | ||||||
field=self.field, | ||||||
metadata=deepcopy(self.metadata, memo) | ||||||
) | ||||||
|
||||||
|
@@ -142,47 +177,38 @@ def __deepcopy__(self, memo): | |||||
|
||||||
@property | ||||||
def type(self) -> str: | ||||||
logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: I would not use the root logger. I would add `Logger = logging.getLogger(__file__)` after the imports, and then call `Logger.warning(...)` instead. |
||||||
return self.metadata.get("type", None) | ||||||
|
||||||
@type.setter | ||||||
def type(self, type: Union[str, None]) -> None: | ||||||
logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') | ||||||
self.metadata.type = type | ||||||
|
||||||
|
||||||
class SpanGroup(Annotation): | ||||||
|
||||||
def __init__( | ||||||
self, | ||||||
spans: List[Span], | ||||||
box_group: Optional[BoxGroup] = None, | ||||||
id: Optional[int] = None, | ||||||
doc: Optional['Document'] = None, | ||||||
field: Optional[str] = None, | ||||||
metadata: Optional[Metadata] = None, | ||||||
): | ||||||
self.spans = spans | ||||||
self.box_group = box_group | ||||||
super().__init__(id=id, doc=doc, metadata=metadata) | ||||||
super().__init__(id=id, doc=doc, field=field, metadata=metadata) | ||||||
|
||||||
@property | ||||||
def symbols(self) -> List[str]: | ||||||
if self.doc is not None: | ||||||
return [ | ||||||
self.doc.symbols[span.start: span.end] for span in self.spans | ||||||
] | ||||||
return [self.doc.symbols[span.start: span.end] for span in self.spans] | ||||||
else: | ||||||
return [] | ||||||
|
||||||
def annotate( | ||||||
self, is_overwrite: bool = False, **kwargs: Iterable["Annotation"] | ||||||
) -> None: | ||||||
if self.doc is None: | ||||||
raise ValueError("SpanGroup has no attached document!") | ||||||
|
||||||
key_remaps = {k: v for k, v in kwargs.items()} | ||||||
|
||||||
self.doc.annotate(is_overwrite=is_overwrite, **key_remaps) | ||||||
|
||||||
def to_json(self) -> Dict: | ||||||
"""Note: even if `doc` or `field` are attached, don't include in JSON to avoid bloat""" | ||||||
span_group_dict = dict( | ||||||
spans=[span.to_json() for span in self.spans], | ||||||
id=self.id, | ||||||
|
@@ -208,7 +234,7 @@ def from_json(cls, span_group_dict: Dict) -> "SpanGroup": | |||||
else: | ||||||
# this fallback is necessary to ensure compatibility with span | ||||||
# groups that were created before the metadata migration and | ||||||
# therefore have "id", "type" in the root of the json dict instead. | ||||||
# therefore have "type" in the root of the json dict instead. | ||||||
metadata_dict = { | ||||||
"type": span_group_dict.get("type", None), | ||||||
"text": span_group_dict.get("text", None) | ||||||
|
@@ -255,6 +281,7 @@ def __deepcopy__(self, memo): | |||||
span_group = SpanGroup( | ||||||
spans=deepcopy(self.spans, memo), | ||||||
id=self.id, | ||||||
field=self.field, | ||||||
metadata=deepcopy(self.metadata, memo), | ||||||
box_group=deepcopy(self.box_group, memo) | ||||||
) | ||||||
|
@@ -266,10 +293,12 @@ def __deepcopy__(self, memo): | |||||
|
||||||
@property | ||||||
def type(self) -> str: | ||||||
logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') | ||||||
return self.metadata.get("type", None) | ||||||
|
||||||
@type.setter | ||||||
def type(self, type: Union[str, None]) -> None: | ||||||
logging.warning(msg='`.type` to be deprecated in future versions. Use `.metadata.type`') | ||||||
self.metadata.type = type | ||||||
|
||||||
@property | ||||||
|
@@ -284,6 +313,49 @@ def text(self, text: Union[str, None]) -> None: | |||||
self.metadata.text = text | ||||||
|
||||||
|
||||||
|
||||||
class Relation(Annotation):
    """An Annotation linking a `key` SpanGroup to a `value` SpanGroup.

    Both endpoints must have a `.name`. The JSON form stores those names as
    strings in place of the SpanGroups themselves.
    """

    def __init__(
            self,
            key: SpanGroup,
            value: SpanGroup,
            id: Optional[int] = None,
            doc: Optional['Document'] = None,
            field: Optional[str] = None,
            metadata: Optional[Metadata] = None
    ):
        """Store both endpoints, then initialize the base Annotation.

        Raises:
            ValueError: if `key.name` or `value.name` is None.
        """
        # Endpoints are serialized by name, so unnamed ones are rejected.
        if key.name is None:
            raise ValueError(f'Relation requires the key {key} to have a `.name`')
        if value.name is None:
            raise ValueError(f'Relation requires the value {value} to have a `.name`')
        self.key = key
        self.value = value
        super().__init__(id=id, doc=doc, field=field, metadata=metadata)

    def to_json(self) -> Dict:
        """Note: even if `doc` or `field` are attached, don't include in JSON to avoid bloat"""
        serialized = {
            'key': str(self.key.name),
            'value': str(self.value.name),
            'id': self.id,
            'metadata': self.metadata.to_json(),
        }
        # only serialize non-null values
        return {k: v for k, v in serialized.items() if v is not None}

    @classmethod
    def from_json(
            cls,
            relation_dict: Dict,
            doc: 'Document',
    ) -> "Relation":
        """Rebuild a Relation from its JSON dict, resolving endpoints via `doc`.

        `relation_dict['key']` and `relation_dict['value']` are parsed with
        `AnnotationName.from_str` and passed to `doc.locate_annotation`
        (presumably returning the annotation with that name; confirm against
        Document). `id` and `metadata` are optional entries.
        """
        # Parse both names before locating either, so a malformed entry fails
        # before any document lookup happens.
        key_name = AnnotationName.from_str(s=relation_dict['key'])
        value_name = AnnotationName.from_str(s=relation_dict['value'])
        key_group = doc.locate_annotation(name=key_name)
        value_group = doc.locate_annotation(name=value_name)
        relation_id = relation_dict.get("id")
        metadata = Metadata.from_json(relation_dict.get('metadata', {}))
        return cls(
            key=key_group,
            value=value_group,
            id=relation_id,
            metadata=metadata,
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Speeds up class creation by roughly 20%: