diff --git a/changelog.md b/changelog.md
index d16fc98b7..0b7bc1c47 100644
--- a/changelog.md
+++ b/changelog.md
@@ -14,6 +14,7 @@
- Added a `context_getter` SpanGetter argument to the `eds.matcher` class to only retrieve entities inside the spans returned by the getter
- Added a `filter_expr` parameter to scorers to filter the documents to score
- Added a new `required` field to `eds.contextual_matcher` assign patterns to only match if the required field has been found, and an `include` parameter (similar to `exclude`) to search for required patterns without assigning them to the entity
+- Added context strings (e.g., "words[0:5] | sents[0:1]") to the `eds.contextual_matcher` component, allowing more complex selection of the window around the trigger spans
### Changed
diff --git a/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py b/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py
index 124933f28..2d7a56f71 100644
--- a/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py
+++ b/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py
@@ -1,8 +1,7 @@
import copy
import re
import warnings
-from functools import lru_cache
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
from confit import VisibleDeprecationWarning
from loguru import logger
@@ -21,30 +20,6 @@
from .models import FullConfig, SingleAssignModel, SingleConfig
-@lru_cache(64)
-def get_window(
- doclike: Union[Doc, Span], window: Tuple[int, int], limit_to_sentence: bool
-):
- """
- Generate a window around the first parameter
- """
- start_limit = doclike.sent.start if limit_to_sentence else 0
- end_limit = doclike.sent.end if limit_to_sentence else len(doclike.doc)
-
- start = (
- max(doclike.start + window[0], start_limit)
- if window and window[0] is not None
- else start_limit
- )
- end = (
- min(doclike.end + window[1], end_limit)
- if window and window[0] is not None
- else end_limit
- )
-
- return doclike.doc[start:end]
-
-
class ContextualMatcher(BaseNERComponent):
"""
Allows additional matching in the surrounding context of the main match group,
@@ -252,11 +227,7 @@ def filter_one(self, span: Span) -> Span:
source = span.label_
to_keep = True
for exclude in self.patterns[source].exclude:
- snippet = get_window(
- doclike=span,
- window=exclude.window,
- limit_to_sentence=exclude.limit_to_sentence,
- )
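+ # `exclude.window` is now a Context object: calling it returns the
+ # snippet of the doc around `span` in which to search for the pattern.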
+ snippet = exclude.window(span)
if next(exclude.matcher(snippet, as_spans=True), None) is not None:
to_keep = False
@@ -264,11 +235,7 @@ def filter_one(self, span: Span) -> Span:
break
for include in self.patterns[source].include:
- snippet = get_window(
- doclike=span,
- window=include.window,
- limit_to_sentence=include.limit_to_sentence,
- )
+ snippet = include.window(span)
if next(include.matcher(snippet, as_spans=True), None) is None:
to_keep = False
@@ -308,13 +275,7 @@ def assign_one(self, span: Span) -> Span:
for assign in self.patterns[source].assign:
assign: SingleAssignModel
window = assign.window
- limit_to_sentence = assign.limit_to_sentence
-
- snippet = get_window(
- doclike=span,
- window=window,
- limit_to_sentence=limit_to_sentence,
- )
+ snippet = window(span)
matcher: RegexMatcher = assign.matcher
if matcher is not None:
diff --git a/edsnlp/pipes/core/contextual_matcher/models.py b/edsnlp/pipes/core/contextual_matcher/models.py
index f4aa44edb..3db6c994f 100644
--- a/edsnlp/pipes/core/contextual_matcher/models.py
+++ b/edsnlp/pipes/core/contextual_matcher/models.py
@@ -1,37 +1,14 @@
import re
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union
import regex
-from pydantic import BaseModel, Extra, validator
+from pydantic import BaseModel, Extra, root_validator, validator
from edsnlp.matchers.utils import ListOrStr
-from edsnlp.utils.span_getters import SpanGetterArg
+from edsnlp.utils.span_getters import Context, SentenceContext, SpanGetterArg
from edsnlp.utils.typing import AsList
Flags = Union[re.RegexFlag, int]
-Window = Union[
- Tuple[int, int],
- List[int],
- int,
-]
-
-
-def normalize_window(cls, v):
- if v is None:
- return v
- if isinstance(v, list):
- assert (
- len(v) == 2
- ), "`window` should be a tuple/list of two integer, or a single integer"
- v = tuple(v)
- if isinstance(v, int):
- assert v != 0, "The provided `window` should not be 0"
- if v < 0:
- return (v, 0)
- if v > 0:
- return (0, v)
- assert v[0] < v[1], "The provided `window` should contain at least 1 token"
- return v
class AssignDict(dict):
@@ -101,9 +78,10 @@ class SingleExcludeModel(BaseModel):
----------
regex: ListOrStr
A single Regex or a list of Regexes
- window: Optional[Window]
+ window: Optional[Context]
Size of the context to use (in number of words). You can provide the window as:
+ - A [context string][context-string] (e.g. "words[-10:10] | sents[0:0]")
- A positive integer, in this case the used context will be taken **after**
the extraction
- A negative integer, in this case the used context will be taken **before**
@@ -121,8 +99,8 @@ class SingleExcludeModel(BaseModel):
"""
regex: ListOrStr = []
- window: Optional[Window] = None
- limit_to_sentence: Optional[bool] = True
+ limit_to_sentence: Optional[bool] = None
+ window: Optional[Context] = None
regex_flags: Optional[Flags] = None
regex_attr: Optional[str] = None
matcher: Optional[Any] = None
@@ -133,7 +111,20 @@ def exclude_regex_validation(cls, v):
v = [v]
return v
- _normalize_window = validator("window", allow_reuse=True)(normalize_window)
+ @validator("limit_to_sentence", pre=True, always=True)
+ def backward_compat_auto_limit_to_sentence(cls, v, values):
+ if (
+ isinstance(values.get("window"), (type(None), int, tuple, list))
+ and v is None
+ ):
+ v = True
+ return v
+
+ @validator("window", always=True)
+ def backward_compat_intersect_sentence(cls, v, values):
+ if values.get("limit_to_sentence"):
+ v = v & SentenceContext(0, 0)
+ return v
class SingleIncludeModel(BaseModel):
@@ -146,9 +137,10 @@ class SingleIncludeModel(BaseModel):
----------
regex: ListOrStr
A single Regex or a list of Regexes
- window: Optional[Window]
+ window: Optional[Context]
Size of the context to use (in number of words). You can provide the window as:
+ - A [context string][context-string] (e.g. "words[-10:10] | sents[0:0]")
- A positive integer, in this case the used context will be taken **after**
the extraction
- A negative integer, in this case the used context will be taken **before**
@@ -166,8 +158,8 @@ class SingleIncludeModel(BaseModel):
"""
regex: ListOrStr = []
- window: Optional[Window] = None
- limit_to_sentence: Optional[bool] = True
+ limit_to_sentence: Optional[bool] = None
+ window: Optional[Context] = None
regex_flags: Optional[Flags] = None
regex_attr: Optional[str] = None
matcher: Optional[Any] = None
@@ -178,7 +170,20 @@ def exclude_regex_validation(cls, v):
v = [v]
return v
- _normalize_window = validator("window", allow_reuse=True)(normalize_window)
+ @validator("limit_to_sentence", pre=True, always=True)
+ def backward_compat_auto_limit_to_sentence(cls, v, values):
+ if (
+ isinstance(values.get("window"), (type(None), int, tuple, list))
+ and v is None
+ ):
+ v = True
+ return v
+
+ @validator("window", always=True)
+ def backward_compat_intersect_sentence(cls, v, values):
+ if values.get("limit_to_sentence"):
+ v = v & SentenceContext(0, 0)
+ return v
class ExcludeModel(AsList[SingleExcludeModel]):
@@ -204,9 +209,10 @@ class SingleAssignModel(BaseModel):
----------
name: ListOrStr
A name (string)
- window: Optional[Window]
+ window: Optional[Context]
Size of the context to use (in number of words). You can provide the window as:
+ - A [context string][context-string] (e.g. "words[-10:10] | sents[0:0]")
- A positive integer, in this case the used context will be taken **after**
the extraction
- A negative integer, in this case the used context will be taken **before**
@@ -217,7 +223,7 @@ class SingleAssignModel(BaseModel):
span_getter: Optional[SpanGetterArg]
A span getter to pick the assigned spans from already extracted entities
in the doc.
- regex: Optional[Window]
+ regex: Optional[str]
A dictionary where keys are labels and values are **Regexes with a single
capturing group**
replace_entity: Optional[bool]
@@ -235,8 +241,8 @@ class SingleAssignModel(BaseModel):
name: str
regex: Optional[str] = None
span_getter: Optional[SpanGetterArg] = None
- window: Optional[Window] = None
- limit_to_sentence: Optional[bool] = True
+ limit_to_sentence: Optional[bool] = None
+ window: Optional[Context] = None
regex_flags: Optional[Flags] = None
regex_attr: Optional[str] = None
replace_entity: bool = False
@@ -259,7 +265,20 @@ def check_single_regex_group(cls, pat):
return pat
- _normalize_window = validator("window", allow_reuse=True)(normalize_window)
+ @validator("limit_to_sentence", pre=True, always=True)
+ def backward_compat_auto_limit_to_sentence(cls, v, values):
+ if (
+ isinstance(values.get("window"), (type(None), int, tuple, list))
+ and v is None
+ ):
+ v = True
+ return v
+
+ @validator("window", always=True)
+ def backward_compat_intersect_sentence(cls, v, values):
+ if values.get("limit_to_sentence"):
+ v = v & SentenceContext(0, 0)
+ return v
class AssignModel(AsList[SingleAssignModel]):
diff --git a/edsnlp/utils/span_getters.py b/edsnlp/utils/span_getters.py
index 6a39c3332..e0e9e3b1c 100644
--- a/edsnlp/utils/span_getters.py
+++ b/edsnlp/utils/span_getters.py
@@ -1,3 +1,4 @@
+import abc
from collections import defaultdict
from typing import (
TYPE_CHECKING,
@@ -11,6 +12,7 @@
Union,
)
+import numpy as np
from pydantic import NonNegativeInt
from spacy.tokens import Doc, Span
@@ -303,3 +305,152 @@ def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]:
end = min(len(span.doc), max(end, max_end_sent))
return span.doc[start:end]
+
+
+class ContextMeta(abc.ABCMeta):
+ pass
+
+
+class Context(abc.ABC, metaclass=ContextMeta):
+ @abc.abstractmethod
+ def __call__(self, span: Span) -> Span:
+ pass
+
+ # Combinators: `a & b` intersects two contexts, `a | b` unions them;
+ # nested combinations are flattened into a single n-ary context.
+ def __and__(self, other: "Context"):
+ # fmt: off
+ return IntersectionContext([
+ *(self.contexts if isinstance(self, IntersectionContext) else (self,)),
+ *(other.contexts if isinstance(other, IntersectionContext) else (other,))
+ ])
+ # fmt: on
+
+ def __or__(self, other: "Context"):
+ # fmt: off
+ return UnionContext([
+ *(self.contexts if isinstance(self, UnionContext) else (self,)),
+ *(other.contexts if isinstance(other, UnionContext) else (other,))
+ ])
+ # fmt: on
+
+ @classmethod
+ def parse(cls, query):
+ return eval(
+ query,
+ {"__builtins__": None},
+ {
+ "words": WordContext,
+ "sents": SentenceContext,
+ },
+ )
+
+ @classmethod
+ def validate(cls, obj, config=None):
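+ # Pydantic hook: accept a context string, a legacy (before, after)
+ # tuple/list, or a legacy signed integer window.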
+ if isinstance(obj, str):
+ return cls.parse(obj)
+ if isinstance(obj, tuple):
+ assert len(obj) == 2
+ return WordContext(*obj)
+ if isinstance(obj, int):
+ assert obj != 0, "The provided `window` should not be 0"
+ return WordContext(obj, 0) if obj < 0 else WordContext(0, obj)
+ raise ValueError(f"Invalid context: {obj}")
+
+ @classmethod
+ def __get_validators__(cls):
+ yield cls.validate
+
+
+class LeafContextMeta(ContextMeta):
+ def __getitem__(cls, item) -> Span:
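+ # Slicing sugar used by context strings, e.g.
+ # words[-5:5] -> WordContext(-5, 5) and sents[0:0] -> SentenceContext(0, 0).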
+ assert isinstance(item, slice)
+ before = item.start
+ after = item.stop
+ return cls(before, after)
+
+
+class LeafContext(Context, metaclass=LeafContextMeta):
+ pass
+
+
+class WordContext(LeafContext):
+ def __init__(
+ self,
+ before: Optional[int] = None,
+ after: Optional[int] = None,
+ ):
+ self.before = before
+ self.after = after
+
+ def __call__(self, span):
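+ # A None bound means "unbounded", i.e. extend to the doc boundary.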
+ start = span.start + self.before if self.before is not None else 0
+ end = span.end + self.after if self.after is not None else len(span.doc)
+ return span.doc[max(0, start) : min(len(span.doc), end)]
+
+ def __repr__(self):
+ return "words[{}:{}]".format(self.before, self.after)
+
+
+class SentenceContext(LeafContext):
+ def __init__(
+ self,
+ before: Optional[int] = None,
+ after: Optional[int] = None,
+ ):
+ self.before = before
+ self.after = after
+
+ def __call__(self, span):
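+ # Mark the first token of each sentence, then build a per-token
+ # sentence index, shifted so that the span's own sentence is 0.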
+ sent_starts = span.doc.to_array("SENT_START") == 1
+ sent_indices = sent_starts.cumsum()
+ sent_indices = sent_indices - sent_indices[span.start]
+
+ start_idx = end_idx = None
+ if self.before is not None:
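+ # Window start: first token of sentence `before`, counted relative
+ # to the span's own sentence (negative means earlier sentences).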
+ start = sent_starts & (sent_indices == self.before)
+ x = np.flatnonzero(start)
+ start_idx = x[-1] if len(x) else 0
+
+ if self.after is not None:
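+ # Window end: first token of the sentence following sentence
+ # `after`; doc slices exclude the end index, so the whole sentence
+ # is kept.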
+ end = sent_starts & (sent_indices == self.after + 1)
+ x = np.flatnonzero(end)
+ end_idx = x[0] if len(x) else len(span.doc)
+
+ return span.doc[start_idx:end_idx]
+
+ def __repr__(self):
+ return "sents[{}:{}]".format(self.before, self.after)
+
+
+class UnionContext(Context):
+ def __init__(
+ self,
+ contexts: AsList[Context],
+ ):
+ self.contexts = contexts
+
+ def __call__(self, span):
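+ # Union over token ranges: the smallest contiguous span covering
+ # every sub-context (Span objects must be contiguous).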
+ results = [context(span) for context in self.contexts]
+ min_word = min([span.start for span in results])
+ max_word = max([span.end for span in results])
+ return span.doc[min_word:max_word]
+
+ def __repr__(self):
+ return " | ".join(repr(context) for context in self.contexts)
+
+
+class IntersectionContext(Context):
+ def __init__(
+ self,
+ contexts: AsList[Context],
+ ):
+ self.contexts = contexts
+
+ def __call__(self, span):
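+ # Intersection over token ranges: the overlap of all sub-contexts,
+ # which may be an empty span.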
+ results = [context(span) for context in self.contexts]
+ min_word = max([span.start for span in results])
+ max_word = min([span.end for span in results])
+ return span.doc[min_word:max_word]
+
+ def __repr__(self):
+ return " & ".join(repr(context) for context in self.contexts)
diff --git a/tests/pipelines/core/test_contextual_matcher.py b/tests/pipelines/core/test_contextual_matcher.py
index 7f4aaf6e7..674890d38 100644
--- a/tests/pipelines/core/test_contextual_matcher.py
+++ b/tests/pipelines/core/test_contextual_matcher.py
@@ -151,12 +155,11 @@
(False, False, "keep_last", None),
(False, False, "keep_last", "keep_first"),
(False, False, "keep_last", "keep_last"),
]
@pytest.mark.parametrize("params,example", list(zip(ALL_PARAMS, EXAMPLES)))
def test_contextual(blank_nlp, params, example):
-
include_assigned, replace_entity, reduce_mode_stage, reduce_mode_metastase = params
blank_nlp.add_pipe(
@@ -225,9 +228,7 @@ def test_contextual(blank_nlp, params, example):
assert len(doc.ents) == len(entities)
for entity, ent in zip(entities, doc.ents):
-
for modifier in entity.modifiers:
-
assert (
rgetattr(ent, modifier.key) == modifier.value
), f"{modifier.key} labels don't match."