diff --git a/recon/loaders.py b/recon/loaders.py index f3ca606..a957e84 100644 --- a/recon/loaders.py +++ b/recon/loaders.py @@ -8,7 +8,7 @@ import spacy import srsly from spacy.language import Language -from spacy.tokens import Doc, DocBin +from spacy.tokens import Doc, DocBin, Span as SpacySpan from spacy.util import get_words_and_spaces from recon.types import Example, Span, Token @@ -120,9 +120,10 @@ def to_spacy( tokens = [token.text for token in example.tokens] words, spaces = get_words_and_spaces(tokens, example.text) doc = Doc(nlp.vocab, words=words, spaces=spaces) - doc.set_ents( - [doc.char_span(s.start, s.end, label=s.label) for s in example.spans] - ) + spacy_spans = [ + doc.char_span(s.start, s.end, label=s.label) for s in example.spans + ] + doc.set_ents(cast(List[SpacySpan], spacy_spans)) doc_bin.add(doc) doc_bin.to_disk(path) return doc_bin diff --git a/recon/types.py b/recon/types.py index 8318f4c..11a1fb6 100644 --- a/recon/types.py +++ b/recon/types.py @@ -13,12 +13,13 @@ Protocol, Tuple, Union, + cast, ) from typing_extensions import ParamSpec from pydantic import BaseModel, field_validator, model_validator from spacy import displacy -from spacy.tokens import Doc +from spacy.tokens import Doc, Span as SpacySpan from spacy.util import get_words_and_spaces from spacy.vocab import Vocab from wasabi import color @@ -130,7 +131,9 @@ def doc(self) -> Doc: tokens = [token.text for token in self.tokens] words, spaces = get_words_and_spaces(tokens, self.text) doc = Doc(Vocab(), words=words, spaces=spaces) - doc.set_ents([doc.char_span(s.start, s.end, label=s.label) for s in self.spans]) + spans = [doc.char_span(s.start, s.end, label=s.label) for s in self.spans] + + doc.set_ents(cast(List[SpacySpan], spans)) return doc def show( @@ -369,9 +372,10 @@ def show(self, label_suffix: str = "PRED"): tokens = [token.text for token in combined.tokens] words, spaces = get_words_and_spaces(tokens, combined.text) doc = Doc(Vocab(), words=words, spaces=spaces) - doc.spans["ref"] = [ + ref_spans = [ doc.char_span(s.start, s.end, label=s.label) for s in combined.spans ] + doc.spans["ref"] = cast(List[SpacySpan], ref_spans) displacy.render(doc, style="span", jupyter=True, options={"spans_key": "ref"})