From 53bcc0e70914b8ecf7db93f605c16ffac8a98035 Mon Sep 17 00:00:00 2001 From: Perceval Wajsburt Date: Fri, 3 May 2024 08:53:23 +0000 Subject: [PATCH] feat: add required & include parameter & support for span_getter in eds.contextual_matcher --- changelog.md | 2 + docs/assets/stylesheets/extra.css | 17 ++ docs/pipes/core/contextual-matcher.md | 70 +---- edsnlp/matchers/regex.py | 4 +- .../contextual_matcher/contextual_matcher.py | 286 ++++++++++-------- .../pipes/core/contextual_matcher/models.py | 231 +++++++++++--- edsnlp/utils/typing.py | 3 +- pyproject.toml | 4 +- 8 files changed, 381 insertions(+), 236 deletions(-) diff --git a/changelog.md b/changelog.md index 2ff94fc8b..f1ad038bb 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,8 @@ - Expose the defaults patterns of `eds.negation`, `eds.hypothesis`, `eds.family`, `eds.history` and `eds.reported_speech` under a `eds.negation.default_patterns` attribute - Added a `context_getter` SpanGetter argument to the `eds.matcher` class to only retrieve entities inside the spans returned by the getter +- Added a `filter_expr` parameter to scorers to filter the documents to score +- Added a new `required` field to `eds.contextual_matcher` assign patterns to only match if the required field has been found, and an `include` parameter (similar to `exclude`) to search for required patterns without assigning them to the entity ## v0.11.2 diff --git a/docs/assets/stylesheets/extra.css b/docs/assets/stylesheets/extra.css index 0193f144b..e856dba66 100644 --- a/docs/assets/stylesheets/extra.css +++ b/docs/assets/stylesheets/extra.css @@ -166,3 +166,20 @@ body, input { .md-typeset code a:not(.md-annotation__index) { border-bottom: 1px dashed var(--md-typeset-a-color); } + +.doc-param-details .subdoc { + padding: 0; + box-shadow: none; + border-color: var(--md-typeset-table-color); +} + +.doc-param-details .subdoc > div > div > div> table { + padding: 0; + box-shadow: none; + border: none; +} + +.doc-param-details .subdoc > summary { + margin: 0; + font-weight: normal; +} diff --git a/docs/pipes/core/contextual-matcher.md b/docs/pipes/core/contextual-matcher.md index 8f7e21c29..7a5697eda 100644 --- a/docs/pipes/core/contextual-matcher.md +++ b/docs/pipes/core/contextual-matcher.md @@ -206,74 +206,6 @@ Let us see what we can get from this pipeline with a few examples However, most of the configuration is provided in the `patterns` key, as a **pattern dictionary** or a **list of pattern dictionaries** -## The pattern dictionary - -### Description - -A patterr is a nested dictionary with the following keys: - -=== "`source`" - - A label describing the pattern - -=== "`regex`" - - A single Regex or a list of Regexes - -=== "`regex_attr`" - - An attributes to overwrite the given `attr` when matching with Regexes. - -=== "`terms`" - - A single term or a list of terms (for exact matches) - -=== "`exclude`" - - A dictionary (or list of dictionaries) to define exclusion rules. Exclusion rules are given as Regexes, and if a - match is found in the surrounding context of an extraction, the extraction is removed. Each dictionary should have the following keys: - - === "`window`" - - Size of the context to use (in number of words). You can provide the window as: - - - A positive integer, in this case the used context will be taken **after** the extraction - - A negative integer, in this case the used context will be taken **before** the extraction - - A tuple of integers `(start, end)`, in this case the used context will be the snippet from `start` tokens before the extraction to `end` tokens after the extraction - - === "`regex`" - - A single Regex or a list of Regexes. - -=== "`assign`" - - A dictionary to refine the extraction. Similarily to the `exclude` key, you can provide a dictionary to - use on the context **before** and **after** the extraction. - - === "`name`" - - A name (string) - - === "`window`" - - Size of the context to use (in number of words). You can provide the window as: - - - A positive integer, in this case the used context will be taken **after** the extraction - - A negative integer, in this case the used context will be taken **before** the extraction - - A tuple of integers `(start, end)`, in this case the used context will be the snippet from `start` tokens before the extraction to `end` tokens after the extraction - - === "`regex`" - - A dictionary where keys are labels and values are **Regexes with a single capturing group** - - === "`replace_entity`" - - If set to `True`, the match from the corresponding assign key will be used as entity, instead of the main match. See [this paragraph][the-replace_entity-parameter] - - === "`reduce_mode`" - - Set how multiple assign matches are handled. See the documentation of the [`reduce_mode` parameter][the-reduce_mode-parameter] - ### A full pattern dictionary example ```python @@ -300,6 +232,8 @@ dict( regex=r"(neonatal)", expand_entity=True, window=3, + # keep the extraction only if neonatal is found + required=True, ), dict( name="trans", diff --git a/edsnlp/matchers/regex.py b/edsnlp/matchers/regex.py index 681788535..4c1921238 100644 --- a/edsnlp/matchers/regex.py +++ b/edsnlp/matchers/regex.py @@ -1,6 +1,6 @@ import re from bisect import bisect_left, bisect_right -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from loguru import logger from spacy.tokens import Doc, Span @@ -465,7 +465,7 @@ def __call__( doclike: Union[Doc, Span], as_spans=False, return_groupdict=False, - ) -> Union[Span, Tuple[Span, Dict[str, Any]]]: + ) -> Iterator[Union[Span, Tuple[Span, Dict[str, Any]]]]: """ Performs matching. Yields matches. diff --git a/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py b/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py index a68f343e7..124933f28 100644 --- a/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py +++ b/edsnlp/pipes/core/contextual_matcher/contextual_matcher.py @@ -1,11 +1,9 @@ +import copy import re import warnings -from collections import defaultdict from functools import lru_cache -from operator import attrgetter -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union -import pydantic from confit import VisibleDeprecationWarning from loguru import logger from spacy.tokens import Doc, Span @@ -13,22 +11,36 @@ from edsnlp.core import PipelineProtocol from edsnlp.matchers.phrase import EDSPhraseMatcher from edsnlp.matchers.regex import RegexMatcher, create_span -from edsnlp.matchers.utils import get_text from edsnlp.pipes.base import BaseNERComponent, SpanSetterArg from edsnlp.utils.collections import flatten_once +from edsnlp.utils.doc_to_text import get_text +from edsnlp.utils.span_getters import get_spans +from edsnlp.utils.typing import AsList # noqa: F401 from . import models +from .models import FullConfig, SingleAssignModel, SingleConfig @lru_cache(64) def get_window( doclike: Union[Doc, Span], window: Tuple[int, int], limit_to_sentence: bool ): + """ + Generate a window around the first parameter + """ start_limit = doclike.sent.start if limit_to_sentence else 0 end_limit = doclike.sent.end if limit_to_sentence else len(doclike.doc) - start = max(doclike.start + window[0], start_limit) - end = min(doclike.end + window[1], end_limit) + start = ( + max(doclike.start + window[0], start_limit) + if window and window[0] is not None + else start_limit + ) + end = ( + min(doclike.end + window[1], end_limit) + if window and window[0] is not None + else end_limit + ) return doclike.doc[start:end] @@ -44,8 +56,13 @@ class ContextualMatcher(BaseNERComponent): spaCy `Language` object. name : Optional[str] The name of the pipe - patterns : Union[Dict[str, Any], List[Dict[str, Any]]] - The configuration dictionary + patterns : AsList[SingleConfig] + ??? subdoc "The patterns to match" + + ::: edsnlp.pipes.core.contextual_matcher.models.SingleConfig + options: + only_parameters: "no-header" + show_toc: false assign_as_span : bool Whether to store eventual extractions defined via the `assign` key as Spans or as string @@ -75,7 +92,7 @@ def __init__( nlp: Optional[PipelineProtocol], name: Optional[str] = "contextual_matcher", *, - patterns: Union[Dict[str, Any], List[Dict[str, Any]]], + patterns: FullConfig, assign_as_span: bool = False, alignment_mode: str = "expand", attr: str = "NORM", @@ -104,12 +121,13 @@ def __init__( self.ignore_excluded = ignore_excluded self.ignore_space_tokens = ignore_space_tokens self.alignment_mode = alignment_mode - self.regex_flags = regex_flags + self.regex_flags: Union[re.RegexFlag, int] = regex_flags self.include_assigned = include_assigned # Configuration parsing - patterns = pydantic.parse_obj_as(models.FullConfig, patterns) - self.patterns = {pattern.source: pattern for pattern in patterns} + self.patterns: Dict[str, SingleConfig] = copy.deepcopy( + {pattern.source: pattern for pattern in patterns} + ) # Matchers for the anchors self.phrase_matcher = EDSPhraseMatcher( @@ -146,11 +164,6 @@ def __init__( } ) - self.exclude_matchers = defaultdict( - list - ) # Will contain all the exclusion matchers - self.assign_matchers = defaultdict(list) # Will contain all the assign matchers - # Will contain the reduce mode (for each source and assign matcher) self.reduce_mode = {} @@ -159,71 +172,62 @@ def __init__( self.replace_key = {} for source, p in self.patterns.items(): - p = p.dict() - - for exclude in p["exclude"]: - exclude_matcher = RegexMatcher( - attr=exclude["regex_attr"] or p["regex_attr"] or self.attr, - flags=exclude["regex_flags"] - or p["regex_flags"] - or self.regex_flags, + p: SingleConfig + + for exclude in p.exclude: + exclude.matcher = RegexMatcher( + attr=exclude.regex_attr or p.regex_attr or self.attr, + flags=exclude.regex_flags or p.regex_flags or self.regex_flags, ignore_excluded=ignore_excluded, ignore_space_tokens=ignore_space_tokens, alignment_mode="expand", ) - exclude_matcher.build_patterns(regex={"exclude": exclude["regex"]}) + exclude.matcher.build_patterns(regex={"exclude": exclude.regex}) - self.exclude_matchers[source].append( - dict( - matcher=exclude_matcher, - window=exclude["window"], - limit_to_sentence=exclude["limit_to_sentence"], - ) - ) - - replace_key = None - - for assign in p["assign"]: - assign_matcher = RegexMatcher( - attr=assign["regex_attr"] or p["regex_attr"] or self.attr, - flags=assign["regex_flags"] or p["regex_flags"] or self.regex_flags, + for include in p.include: + include.matcher = RegexMatcher( + attr=include.regex_attr or p.regex_attr or self.attr, + flags=include.regex_flags or p.regex_flags or self.regex_flags, ignore_excluded=ignore_excluded, ignore_space_tokens=ignore_space_tokens, - alignment_mode=alignment_mode, - span_from_group=True, + alignment_mode="expand", ) - assign_matcher.build_patterns( - regex={assign["name"]: assign["regex"]}, - ) + include.matcher.build_patterns(regex={"include": include.regex}) + + replace_key = None - self.assign_matchers[source].append( - dict( - name=assign["name"], - matcher=assign_matcher, - window=assign["window"], - limit_to_sentence=assign["limit_to_sentence"], - replace_entity=assign["replace_entity"], - reduce_mode=assign["reduce_mode"], + for assign in p.assign: + assign.matcher = None + if assign.regex: + assign.matcher = RegexMatcher( + attr=assign.regex_attr or p.regex_attr or self.attr, + flags=assign.regex_flags or p.regex_flags or self.regex_flags, + ignore_excluded=ignore_excluded, + ignore_space_tokens=ignore_space_tokens, + alignment_mode=alignment_mode, + span_from_group=True, ) - ) - if assign["replace_entity"]: + assign.matcher.build_patterns( + regex={assign.name: assign.regex}, + ) + assign.regex = assign.matcher + + if assign.replace_entity: # We know that there is only one assign name # with `replace_entity==True` # from PyDantic validation - replace_key = assign["name"] + replace_key = assign.name + self.reduce_mode[source] = {d.name: d.reduce_mode for d in p.assign} self.replace_key[source] = replace_key - self.reduce_mode[source] = { - d["name"]: d["reduce_mode"] for d in self.assign_matchers[source] - } - - self.set_extensions() - def set_extensions(self) -> None: + """ + Define the extensions used by the component + """ super().set_extensions() if not Span.has_extension("assigned"): Span.set_extension("assigned", default=dict()) @@ -232,8 +236,8 @@ def set_extensions(self) -> None: def filter_one(self, span: Span) -> Span: """ - Filter extracted entity based on the "exclusion filter" mentioned - in the configuration + Filter extracted entity based on the exclusion and inclusion filters of + the configuration. Parameters ---------- @@ -247,22 +251,26 @@ def filter_one(self, span: Span) -> Span: """ source = span.label_ to_keep = True - for matcher in self.exclude_matchers[source]: - window = matcher["window"] - limit_to_sentence = matcher["limit_to_sentence"] + for exclude in self.patterns[source].exclude: snippet = get_window( doclike=span, - window=window, - limit_to_sentence=limit_to_sentence, + window=exclude.window, + limit_to_sentence=exclude.limit_to_sentence, ) - if ( - next( - matcher["matcher"](snippet, as_spans=True), - None, - ) - is not None - ): + if next(exclude.matcher(snippet, as_spans=True), None) is not None: + to_keep = False + logger.trace(f"Entity {span} was filtered out") + break + + for include in self.patterns[source].include: + snippet = get_window( + doclike=span, + window=include.window, + limit_to_sentence=include.limit_to_sentence, + ) + + if next(include.matcher(snippet, as_spans=True), None) is None: to_keep = False logger.trace(f"Entity {span} was filtered out") break @@ -290,18 +298,17 @@ def assign_one(self, span: Span) -> Span: """ if span is None: - yield from [] return source = span.label_ assigned_dict = models.AssignDict(reduce_mode=self.reduce_mode[source]) replace_key = None - for matcher in self.assign_matchers[source]: - attr = self.patterns[source].regex_attr or matcher["matcher"].default_attr - window = matcher["window"] - limit_to_sentence = matcher["limit_to_sentence"] - replace_entity = matcher["replace_entity"] # Boolean + all_assigned_list = [] + for assign in self.patterns[source].assign: + assign: SingleAssignModel + window = assign.window + limit_to_sentence = assign.limit_to_sentence snippet = get_window( doclike=span, @@ -309,53 +316,67 @@ def assign_one(self, span: Span) -> Span: limit_to_sentence=limit_to_sentence, ) - # Getting the matches - assigned_list = list(matcher["matcher"].match(snippet)) - - assigned_list = [ - (span, span, matcher["matcher"].regex[0][0]) - if not match.groups() - else ( - span, - create_span( - doclike=snippet, - start_char=match.start(0), - end_char=match.end(0), - key=matcher["matcher"].regex[0][0], - attr=matcher["matcher"].regex[0][2], - alignment_mode=matcher["matcher"].regex[0][5], - ignore_excluded=matcher["matcher"].regex[0][3], - ignore_space_tokens=matcher["matcher"].regex[0][4], - ), - matcher["matcher"].regex[0][0], - ) - for (span, match) in assigned_list - ] + matcher: RegexMatcher = assign.matcher + if matcher is not None: + # Getting the matches + assigned_list = list(matcher.match(snippet)) + assigned_list = [ + (matched_span, matched_span, matcher.regex[0][0], assign) + if not re_match.groups() + else ( + matched_span, + create_span( + doclike=snippet, + start_char=re_match.start(0), + end_char=re_match.end(0), + key=matcher.regex[0][0], + attr=matcher.regex[0][2], + alignment_mode=matcher.regex[0][5], + ignore_excluded=matcher.regex[0][3], + ignore_space_tokens=matcher.regex[0][4], + ), + matcher.regex[0][0], + assign, + ) + for (matched_span, re_match) in assigned_list + ] + else: + assigned_list = [ + (matched_span, matched_span, assign.name, assign) + for matched_span in get_spans(snippet.doc, assign.span_getter) + if matched_span.start >= snippet.start + and matched_span.end <= snippet.end + ] # assigned_list now contains tuples with # - the first element being the span extracted from the group # - the second element being the full match - if not assigned_list: # No match was found + if assign.required and not assigned_list: + logger.trace(f"Entity {span} was filtered out") + return + + all_assigned_list.extend(assigned_list) + + for assigned in all_assigned_list: + if assigned is None: continue + group_span, full_match_span, value_key, assign = assigned + if assign.replace_entity: + replace_key = value_key + + # Using he overridden `__setitem__` method from AssignDict here: + assigned_dict[value_key] = { + "span": full_match_span, # Full span + "value_span": group_span, # Span of the group + "value_text": get_text( + group_span, + attr=self.patterns[source].regex_attr or self.attr, + ignore_excluded=self.ignore_excluded, + ), # Text of the group + } + logger.trace(f"Assign key {value_key} matched on entity {span}") - for assigned in assigned_list: - if assigned is None: - continue - if replace_entity: - replace_key = assigned[2] - - # Using he overrid `__setitem__` method from AssignDict here: - assigned_dict[assigned[2]] = { - "span": assigned[1], # Full span - "value_span": assigned[0], # Span of the group - "value_text": get_text( - assigned[0], - attr=attr, - ignore_excluded=self.ignore_excluded, - ), # Text of the group - } - logger.trace(f"Assign key {matcher['name']} matched on entity {span}") if replace_key is None and self.replace_key[source] is not None: # There should have been a replacement, but none was found # So we discard the entity @@ -388,13 +409,13 @@ def assign_one(self, span: Span) -> Span: closest = Span( span.doc, - min(expandables, key=attrgetter("start")).start, - max(expandables, key=attrgetter("end")).end, + min(span.start for span in expandables if span is not None), + max(span.end for span in expandables if span is not None), span.label_, ) kept_ents.append(closest) - kept_ents.sort(key=attrgetter("start")) + kept_ents.sort(key=lambda e: e.start) for replaced in kept_ents: # Propagating attributes from the anchor @@ -434,6 +455,19 @@ def assign_one(self, span: Span) -> Span: yield from kept_ents def process_one(self, span): + """ + Processes one span, applying both the filters and the assignments + + Parameters + ---------- + span: + spaCy Span object + + Yields + ------ + span: + Filtered spans, with optional assignments + """ filtered = self.filter_one(span) yield from self.assign_one(filtered) diff --git a/edsnlp/pipes/core/contextual_matcher/models.py b/edsnlp/pipes/core/contextual_matcher/models.py index 4e684189c..f4aa44edb 100644 --- a/edsnlp/pipes/core/contextual_matcher/models.py +++ b/edsnlp/pipes/core/contextual_matcher/models.py @@ -1,11 +1,12 @@ import re -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union -import pydantic import regex from pydantic import BaseModel, Extra, validator from edsnlp.matchers.utils import ListOrStr +from edsnlp.utils.span_getters import SpanGetterArg +from edsnlp.utils.typing import AsList Flags = Union[re.RegexFlag, int] Window = Union[ @@ -16,6 +17,8 @@ def normalize_window(cls, v): + if v is None: + return v if isinstance(v, list): assert ( len(v) == 2 @@ -89,13 +92,42 @@ def keep_last(key, value): class SingleExcludeModel(BaseModel): + """ + A dictionary to define exclusion rules. Exclusion rules are given as Regexes, and + if a match is found in the surrounding context of an extraction, the extraction is + removed. Each dictionary should have the following keys: + + Parameters + ---------- + regex: ListOrStr + A single Regex or a list of Regexes + window: Optional[Window] + Size of the context to use (in number of words). You can provide the window as: + + - A positive integer, in this case the used context will be taken **after** + the extraction + - A negative integer, in this case the used context will be taken **before** + the extraction + - A tuple of integers `(start, end)`, in this case the used context will be + the snippet from `start` tokens before the extraction to `end` tokens + after the extraction + limit_to_sentence: Optional[bool] + If set to `True`, the exclusion will be limited to the sentence containing the + extraction + regex_flags: Optional[Flags] + Flags to use when compiling the Regexes + regex_attr: Optional[str] + An attribute to overwrite the given `attr` when matching with Regexes. + """ + regex: ListOrStr = [] - window: Window + window: Optional[Window] = None limit_to_sentence: Optional[bool] = True regex_flags: Optional[Flags] = None regex_attr: Optional[str] = None + matcher: Optional[Any] = None - @validator("regex") + @validator("regex", allow_reuse=True) def exclude_regex_validation(cls, v): if isinstance(v, str): v = [v] @@ -104,52 +136,138 @@ def exclude_regex_validation(cls, v): _normalize_window = validator("window", allow_reuse=True)(normalize_window) -class ExcludeModel: - @classmethod - def item_to_list(cls, v, config): - if not isinstance(v, list): +class SingleIncludeModel(BaseModel): + """ + A dictionary to define inclusion rules. Inclusion rules are given as Regexes, and + if a match isn't found in the surrounding context of an extraction, the extraction + is removed. Each dictionary should have the following keys: + + Parameters + ---------- + regex: ListOrStr + A single Regex or a list of Regexes + window: Optional[Window] + Size of the context to use (in number of words). You can provide the window as: + + - A positive integer, in this case the used context will be taken **after** + the extraction + - A negative integer, in this case the used context will be taken **before** + the extraction + - A tuple of integers `(start, end)`, in this case the used context will be + the snippet from `start` tokens before the extraction to `end` tokens + after the extraction + limit_to_sentence: Optional[bool] + If set to `True`, the exclusion will be limited to the sentence containing the + extraction + regex_flags: Optional[Flags] + Flags to use when compiling the Regexes + regex_attr: Optional[str] + An attribute to overwrite the given `attr` when matching with Regexes. + """ + + regex: ListOrStr = [] + window: Optional[Window] = None + limit_to_sentence: Optional[bool] = True + regex_flags: Optional[Flags] = None + regex_attr: Optional[str] = None + matcher: Optional[Any] = None + + @validator("regex", allow_reuse=True) + def exclude_regex_validation(cls, v): + if isinstance(v, str): v = [v] - return [pydantic.parse_obj_as(SingleExcludeModel, x) for x in v] + return v - @classmethod - def __get_validators__(cls): - yield cls.item_to_list + _normalize_window = validator("window", allow_reuse=True)(normalize_window) + + +class ExcludeModel(AsList[SingleExcludeModel]): + """ + A list of `SingleExcludeModel` objects. If a single config is passed, + it will be automatically converted to a list of a single element. + """ + + +class IncludeModel(AsList[SingleIncludeModel]): + """ + A list of `SingleIncludeModel` objects. If a single config is passed, + it will be automatically converted to a list of a single element. + """ class SingleAssignModel(BaseModel): + """ + A dictionary to refine the extraction. Similarly to the `exclude` key, you can + provide a dictionary to use on the context **before** and **after** the extraction. + + Parameters + ---------- + name: ListOrStr + A name (string) + window: Optional[Window] + Size of the context to use (in number of words). You can provide the window as: + + - A positive integer, in this case the used context will be taken **after** + the extraction + - A negative integer, in this case the used context will be taken **before** + the extraction + - A tuple of integers `(start, end)`, in this case the used context will be the + snippet from `start` tokens before the extraction to `end` tokens after the + extraction + span_getter: Optional[SpanGetterArg] + A span getter to pick the assigned spans from already extracted entities + in the doc. + regex: Optional[Window] + A dictionary where keys are labels and values are **Regexes with a single + capturing group** + replace_entity: Optional[bool] + If set to `True`, the match from the corresponding assign key will be used as + entity, instead of the main match. + See [this paragraph][the-replace_entity-parameter] + reduce_mode: Optional[Flags] + Set how multiple assign matches are handled. See the documentation of the + [`reduce_mode` parameter][the-reduce_mode-parameter] + required: Optional[str] + If set to `True`, the assign key must match for the extraction to be kept. If + it does not match, the extraction is discarded. + """ + name: str - regex: str - window: Window + regex: Optional[str] = None + span_getter: Optional[SpanGetterArg] = None + window: Optional[Window] = None limit_to_sentence: Optional[bool] = True regex_flags: Optional[Flags] = None regex_attr: Optional[str] = None replace_entity: bool = False reduce_mode: Optional[str] = None + required: Optional[bool] = False + + matcher: Optional[Any] = None - @validator("regex") + @validator("regex", allow_reuse=True) def check_single_regex_group(cls, pat): + if pat is None: + return pat compiled_pat = regex.compile( pat ) # Using regex to allow multiple fgroups with same name n_groups = compiled_pat.groups - assert n_groups == 1, ( - "The pattern {pat} should have only one capturing group, not {n_groups}" - ).format( - pat=pat, - n_groups=n_groups, - ) + assert ( + n_groups == 1 + ), f"The pattern {pat} should have exactly one capturing group, not {n_groups}" return pat _normalize_window = validator("window", allow_reuse=True)(normalize_window) -class AssignModel: - @classmethod - def item_to_list(cls, v, config): - if not isinstance(v, list): - v = [v] - return [pydantic.parse_obj_as(SingleAssignModel, x) for x in v] +class AssignModel(AsList[SingleAssignModel]): + """ + A list of `SingleAssignModel` objects that should have at most + one element with `replace_entity=True`. If a single config is passed, + it will be automatically converted to a list of a single element. + """ @classmethod def name_uniqueness(cls, v, config): @@ -167,27 +285,62 @@ def replace_uniqueness(cls, v, config): @classmethod def __get_validators__(cls): - yield cls.item_to_list + yield cls.validate yield cls.name_uniqueness yield cls.replace_uniqueness +if TYPE_CHECKING: + ExcludeModel = List[SingleExcludeModel] # noqa: F811 + IncludeModel = List[SingleIncludeModel] # noqa: F811 + AssignModel = List[SingleAssignModel] # noqa: F811 + + class SingleConfig(BaseModel, extra=Extra.forbid): + """ + A single configuration for the contextual matcher. + + Parameters + ---------- + source : str + A label describing the pattern + regex : ListOrStr + A single Regex or a list of Regexes + regex_attr : Optional[str] + An attributes to overwrite the given `attr` when matching with Regexes. + terms : Union[re.RegexFlag, int] + A single term or a list of terms (for exact matches) + exclude : AsList[SingleExcludeModel] + ??? subdoc "One or more exclusion patterns" + + ::: edsnlp.pipes.core.contextual_matcher.models.SingleExcludeModel + options: + only_parameters: "no-header" + assign : AsList[SingleAssignModel] + ??? subdoc "One or more assignment patterns" + + ::: edsnlp.pipes.core.contextual_matcher.models.SingleAssignModel + options: + only_parameters: "no-header" + + """ + source: str terms: ListOrStr = [] regex: ListOrStr = [] regex_attr: Optional[str] = None regex_flags: Union[re.RegexFlag, int] = None - exclude: Optional[ExcludeModel] = [] - assign: Optional[AssignModel] = [] + exclude: ExcludeModel = [] + include: IncludeModel = [] + assign: AssignModel = [] -class FullConfig: - @classmethod - def pattern_to_list(cls, v, config): - if not isinstance(v, list): - v = [v] - return [pydantic.parse_obj_as(SingleConfig, item) for item in v] +class FullConfig(AsList[SingleConfig]): + """ + A list of `SingleConfig` objects that should have distinct `source` fields. + If a single config is passed, it will be automatically converted to a list of + a single element. + """ @classmethod def source_uniqueness(cls, v, config): @@ -197,5 +350,9 @@ def source_uniqueness(cls, v, config): @classmethod def __get_validators__(cls): - yield cls.pattern_to_list + yield cls.validate yield cls.source_uniqueness + + +if TYPE_CHECKING: + FullConfig = List[SingleConfig] # noqa: F811 diff --git a/edsnlp/utils/typing.py b/edsnlp/utils/typing.py index 6562642dc..68b1fcbdf 100644 --- a/edsnlp/utils/typing.py +++ b/edsnlp/utils/typing.py @@ -9,7 +9,8 @@ class MetaAsList(type): def __init__(cls, name, bases, dct): super().__init__(name, bases, dct) - cls.item = Any + item = next((base.item for base in bases if hasattr(base, "item")), Any) + cls.item = item def __getitem__(self, item): new_type = MetaAsList(self.__name__, (self,), {}) diff --git a/pyproject.toml b/pyproject.toml index 9b347117a..36ac8fbe6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -338,9 +338,9 @@ ignore-property-decorators = false ignore-module = true ignore-nested-functions = true ignore-nested-classes = true -ignore-setters = false +ignore-setters = true fail-under = 40 -exclude = ["setup.py", "docs", "build", "tests"] +exclude = ["setup.py", "docs", "build", "tests", "edsnlp/pipes/core/contextual_matcher/models.py"] verbose = 0 quiet = false whitelist-regex = []