Skip to content

Commit

Permalink
adding possibility to select cim10 and atc in eds.cim10 and eds.drugs
Browse files Browse the repository at this point in the history
  • Loading branch information
svittoz committed Aug 28, 2024
1 parent fff384a commit c804921
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 7 deletions.
8 changes: 6 additions & 2 deletions edsnlp/pipes/ner/cim10/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List

from typing_extensions import Literal

Expand Down Expand Up @@ -28,6 +28,7 @@ def create_component(
name: str = "cim10",
*,
attr: str = "NORM",
cim10: List[str] = None,
ignore_excluded: bool = False,
ignore_space_tokens: bool = False,
term_matcher: Literal["exact", "simstring"] = "exact",
Expand Down Expand Up @@ -75,6 +76,9 @@ def create_component(
The pipeline object
name : str
The name of the component
cim10 : Optional[List[str]]
List of CIM-10 codes to retrieve. If None, all CIM-10 codes are searched,
resulting in higher computation time.
attr : str
The default attribute to use for matching.
ignore_excluded : bool
Expand Down Expand Up @@ -104,7 +108,7 @@ def create_component(
nlp=nlp,
name=name,
regex=dict(),
terms=get_patterns(),
terms=get_patterns(cim10),
attr=attr,
ignore_excluded=ignore_excluded,
ignore_space_tokens=ignore_space_tokens,
Expand Down
4 changes: 3 additions & 1 deletion edsnlp/pipes/ner/cim10/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from edsnlp import BASE_DIR


def get_patterns() -> Dict[str, List[str]]:
def get_patterns(cim10: List[str] = None) -> Dict[str, List[str]]:
df = pd.read_csv(BASE_DIR / "resources" / "cim10.csv.gz")

df["code_pattern"] = df["code"]
Expand All @@ -30,4 +30,6 @@ def get_patterns() -> Dict[str, List[str]]:

patterns = df.groupby("code")["patterns"].agg(list).to_dict()

patterns = {k: v for k, v in patterns.items() if k in cim10} if cim10 else patterns

return patterns
8 changes: 6 additions & 2 deletions edsnlp/pipes/ner/drugs/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, List

from typing_extensions import Literal

Expand Down Expand Up @@ -28,6 +28,7 @@ def create_component(
name: str = "drugs",
*,
attr: str = "NORM",
atc: List[str] = None,
ignore_excluded: bool = False,
ignore_space_tokens: bool = False,
term_matcher: Literal["exact", "simstring"] = "exact",
Expand Down Expand Up @@ -83,6 +84,9 @@ def create_component(
The name of the component
attr : str
The default attribute to use for matching.
atc : Optional[List[str]]
List of ATC codes to retrieve. If None, all ATC codes are searched,
resulting in higher computation time.
ignore_excluded : bool
Whether to skip excluded tokens (requires an upstream
pipeline to mark excluded tokens).
Expand Down Expand Up @@ -111,7 +115,7 @@ def create_component(
nlp=nlp,
name=name,
regex=dict(),
terms=get_patterns(),
terms=get_patterns(atc),
attr=attr,
ignore_excluded=ignore_excluded,
ignore_space_tokens=ignore_space_tokens,
Expand Down
13 changes: 11 additions & 2 deletions edsnlp/pipes/ner/drugs/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
drugs_file = BASE_DIR / "resources" / "drugs.json"


def get_patterns() -> Dict[str, List[str]]:
def filter_dict_by_keys(D: Dict[str, List[str]], L: List[str]) -> Dict[str, List[str]]:
    """
    Keep the entries of a dictionary whose key starts with a given prefix.

    Parameters
    ----------
    D : Dict[str, List[str]]
        Dictionary to filter. It is not modified.
    L : List[str]
        Prefixes to look for. An entry is kept as soon as its key starts
        with at least one of them.

    Returns
    -------
    Dict[str, List[str]]
        New dictionary holding the matching entries, in the iteration
        order of `D`. Empty when `L` is empty.
    """
    # Materialise the prefixes once: `str.startswith` accepts a tuple, and
    # this also keeps one-shot iterables (e.g. generators) usable for every key.
    prefixes = tuple(L)
    return {key: values for key, values in D.items() if key.startswith(prefixes)}


def get_patterns(atc: List[str] = None) -> Dict[str, List[str]]:
    """
    Load the drug patterns from `drugs_file`, optionally restricted to `atc`.

    Parameters
    ----------
    atc : List[str], optional
        Keys to keep. When None or empty, every entry of the file is returned.

    Returns
    -------
    Dict[str, List[str]]
        Content of the JSON resource, limited to the keys listed in `atc`
        when it is provided.
    """
    # JSON resources are UTF-8; do not depend on the platform default encoding.
    with open(drugs_file, "r", encoding="utf-8") as f:
        patterns = json.load(f)

    if not atc:
        return patterns

    # NOTE(review): keys are kept on exact membership in `atc`. The
    # prefix-based helper `filter_dict_by_keys` in this module is not used
    # here; confirm which matching rule is intended.
    return {k: v for k, v in patterns.items() if k in atc}

0 comments on commit c804921

Please sign in to comment.