Skip to content

Commit

Permalink
feat: add context_getter argument to eds.matcher
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed May 18, 2024
1 parent d849de7 commit cb2c32c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- `edsnlp.load` now accepts EDS-NLP models from the huggingface hub 🤗!
- New `python -m edsnlp.package` command to package a model for the huggingface hub or pypi-like registries
- Expose the default patterns of `eds.negation`, `eds.hypothesis`, `eds.family`, `eds.history` and `eds.reported_speech` under an `eds.negation.default_patterns` attribute
- Added a `context_getter` SpanGetter argument to the `eds.matcher` class to only retrieve entities inside the spans returned by the getter

### Changed

Expand Down
17 changes: 13 additions & 4 deletions edsnlp/pipes/core/matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from edsnlp.matchers.simstring import SimstringMatcher
from edsnlp.matchers.utils import Patterns
from edsnlp.pipes.base import BaseNERComponent, SpanSetterArg
from edsnlp.utils.span_getters import SpanGetterArg, get_spans


class GenericMatcher(BaseNERComponent):
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
term_matcher: Literal["exact", "simstring"] = "exact",
term_matcher_config: Dict[str, Any] = {},
span_setter: SpanSetterArg = {"ents": True},
context_getter: Optional[SpanGetterArg] = None,
):
super().__init__(nlp=nlp, name=name, span_setter=span_setter)

Expand All @@ -114,6 +116,7 @@ def __init__(
regex = regex or {}

self.attr = attr
self.context_getter = context_getter

if term_matcher == "exact":
self.phrase_matcher = EDSPhraseMatcher(
Expand Down Expand Up @@ -163,10 +166,16 @@ def process(self, doc: Doc) -> List[Span]:
List of Spans returned by the matchers.
"""

matches = self.phrase_matcher(doc, as_spans=True)
regex_matches = self.regex_matcher(doc, as_spans=True)

spans = list(matches) + list(regex_matches)
contexts = (
list(get_spans(doc, self.context_getter))
if self.context_getter is not None
else [doc]
)
spans: List[Span] = []
for context in contexts:
matches = self.phrase_matcher(context, as_spans=True)
regex_matches = self.regex_matcher(context, as_spans=True)
spans.extend(list(matches) + list(regex_matches))

return spans

Expand Down

0 comments on commit cb2c32c

Please sign in to comment.