Word processing and result table generation #150

Open
wants to merge 9 commits into base: master
2 changes: 1 addition & 1 deletion benchmarks/bioid_evaluation.py
@@ -617,7 +617,7 @@ def get_results_tables(
total.loc[:, 'entity_type'] = 'Total'
stats = res_df.groupby('entity_type', as_index=False).sum()
stats = stats[stats['entity_type'] != 'unknown']
stats = stats.append(total, ignore_index=True)
stats = pd.concat([stats, total], ignore_index=True)
stats.loc[:, stats.columns[1:]] = stats[stats.columns[1:]].astype(int)
if match == 'strict':
score_cols = ['top_correct', 'exists_correct']
195 changes: 176 additions & 19 deletions benchmarks/bioid_ner_benchmark.py
@@ -2,12 +2,12 @@
import json
import pathlib
import logging
import pickle
from datetime import datetime
from collections import defaultdict, Counter
import xml.etree.ElementTree as ET
from textwrap import dedent
from typing import List, Dict

import pystow
import pandas as pd
from tqdm import tqdm
@@ -16,7 +16,7 @@
from gilda.ner import annotate

#from benchmarks.bioid_evaluation import fplx_members
from benchmarks.bioid_evaluation import BioIDBenchmarker
from bioid_evaluation import BioIDBenchmarker

logging.getLogger('gilda.grounder').setLevel('WARNING')
logger = logging.getLogger('bioid_ner_benchmark')
@@ -51,6 +51,7 @@ def __init__(self):
self.counts_table = None
self.precision_recall = None
self.false_positives_counter = Counter()
self.result = None

def process_xml_files(self):
"""Extract relevant information from XML files."""
@@ -60,7 +61,9 @@ def process_xml_files(self):
for filename in os.listdir(DATA_DIR):
if filename.endswith('.xml'):
filepath = os.path.join(DATA_DIR, filename)
tree = ET.parse(filepath)
#subprocess.run(['iconv', '-f', 'ASCII', '-t', 'UTF-8', filepath, '-o', filepath], check=True)
with open(filepath, 'r', encoding='utf-8') as file:
tree = ET.parse(file)
root = tree.getroot()
for document in root.findall('.//document'):
doc_id_full = document.find('.//id').text.strip()
@@ -145,11 +148,11 @@ def annotate_entities_with_gilda(self):
doc_id = item['doc_id']
figure = item['figure']
text = item['text']

# Get the full text for the paper-level disambiguation
full_text = self._get_plaintext(doc_id)

gilda_annotations = annotate(text, context_text=full_text)
gilda_annotations = annotate(text, context_text=full_text,
organisms=self._get_organism_priority(doc_id))

for annotation in gilda_annotations:
total_gilda_annotations += 1
@@ -199,32 +202,29 @@ def evaluate_gilda_performance(self):
if match_found:
break

if not match_found:
if not match_found and matching_refs != []:
metrics['all_matches']['fp'] += 1
self.false_positives_counter[annotation.text] += 1
if annotation.matches: # Check if there are any matches
metrics['top_match']['fp'] += 1

# False negative calculation using ref dict
# i.e., the number of reference annotations with no corresponding Gilda annotation
for key, refs in tqdm(ref_dict.items(),
desc="Calculating False Negatives"):
doc_id, figure = key[0], key[1]
gilda_annotations = self.gilda_annotations_map.get((doc_id, figure),
[])
for original_curies, synonyms in refs:
match_found = any(
ann.text == key[2] and
ann.start == key[3] and
ann.end == key[4] and
any(f"{match.term.db}:{match.term.id}" in original_curies or
f"{match.term.db}:{match.term.id}" in synonyms
for match in ann.matches)
for ann in gilda_annotations
match_found = any(
ann.text == key[2] and
ann.start == key[3] and
ann.end == key[4]
for ann in gilda_annotations
)

if not match_found:
metrics['all_matches']['fn'] += 1
metrics['top_match']['fn'] += 1
if not match_found:
metrics['all_matches']['fn'] += 1
metrics['top_match']['fn'] += 1

results = {}
for match_type, counts in metrics.items():
@@ -292,13 +292,154 @@ def get_tables(self):
self.precision_recall,
self.false_positives_counter)

def check_match(self, row):
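"""Return True if any grounding CURIE in the row matches the reference
object or one of its synonyms."""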
obj = row['obj']
obj_synonyms = row['obj_synonyms']
groundings = row['groundings']
if obj_synonyms is None or groundings is None:
return False
for elem in obj_synonyms:
for tup in groundings:
if elem == tup[0]:
return True
for elem in obj:
for tup in groundings:
if elem == tup[0]:
return True
return False

def generate_result_table(self):
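"""Build a per-annotation result table that combines Gilda's annotations
with the reference annotations, including reference entries that Gilda
did not annotate, and flag whether each Gilda grounding matches."""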

ref_dict = {}

for _, row in self.annotations_df.iterrows():
key = (str(row['don_article']), row['figure'], row['text'],
row['first left'], row['last right'])
ref_dict[key] = (row['obj'], row['obj_synonyms'])

text_list, obj_synonyms_list, don_articles_list = [], [], []
groundings_list, entity_type_list, obj_list = [], [], []
figure_list = []
all_annotation = {}

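# One row per Gilda annotation; attach the reference entry (if any)
# that covers exactly the same text span.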
for (doc_id, figure), annotations in (
tqdm(self.gilda_annotations_map.items(),
desc="Getting result")):
for annotation in annotations:
key = (doc_id, figure, annotation.text, annotation.start,
annotation.end)
all_annotation[key] = annotation

matching_refs = ref_dict.get(key, None)

groundings = []
if matching_refs:
obj = matching_refs[0]
obj_synonyms = matching_refs[1]
else:
obj, obj_synonyms = None, None

text = annotation.text
for scored_match in annotation.matches:
curies = []
curie = f"{scored_match.term.db}:{scored_match.term.id}"
score = scored_match.score
groundings.append((curie, score))
curies.append(curie)

if obj:
entity_type = self._get_entity_type(obj)
else:
entity_type = None

obj_list.append(obj)
text_list.append(text)
obj_synonyms_list.append(obj_synonyms)
don_articles_list.append(doc_id)
figure_list.append(figure)
entity_type_list.append(entity_type)
groundings_list.append(groundings)

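# Add reference annotations for which Gilda produced no annotation at
# the same span; these rows have no groundings (false negatives).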
for key, refs in tqdm(ref_dict.items(),
desc="Adding reference-only annotations"):
doc_id, figure = key[0], key[1]
text, start, end = key[2], key[3], key[4]

if not all_annotation.get((doc_id, figure, text, start, end)):
obj_list.append(refs[0]) # ([i[0] for i in refs])
entity_type = self._get_entity_type(refs[0])
entity_type_list.append(entity_type)
text_list.append(key[2])
figure_list.append(key[1])
obj_synonyms_list.append(refs[1]) # ([i[1] for i in refs])
don_articles_list.append(key[0])
groundings_list.append(None)


data = {
'text': text_list,
'obj': obj_list,
'obj_synonyms': obj_synonyms_list,
'don_article': don_articles_list,
'figure': figure_list,
'entity_type': entity_type_list,
'groundings': groundings_list,
}
self.result = pd.DataFrame(data)
self.result['match'] = self.result.apply(self.check_match, axis=1)
self.result = self.result.sort_values(by=['don_article', 'figure'])

def get_entity_result(self):
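"""Compute true positive, false positive, and false negative counts per
entity type from the result table, along with precision and recall, and
append a Total row."""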
df = self.result
# True Positives
df_tp = df[(df['obj'].notna()) & (df['groundings'].notna()) & (df['match'] == True)]
true_positive_counts = df_tp.groupby('entity_type').size().reset_index(name='true_positive_count')
# False Negatives
df_fn = df[(df['obj'].notna()) & (df['groundings'].isna())]
false_negative_counts = df_fn.groupby('entity_type').size().reset_index(name='false_negative_count')
# False Positives
df_fp = df[(df['obj'].notna()) & (df['groundings'].notna()) & (df['match'] == False)]
false_positive_counts = df_fp.groupby('entity_type').size().reset_index(name='false_positive_count')
# Merge
merged_df = pd.merge(true_positive_counts, false_negative_counts, on='entity_type', how='outer').fillna(0)
merged_df = pd.merge(merged_df, false_positive_counts, on='entity_type', how='outer').fillna(0)
#Recall
merged_df['recall'] = merged_df['true_positive_count'] / (
merged_df['true_positive_count'] + merged_df['false_negative_count'])
#Precision
merged_df['precision'] = merged_df['true_positive_count'] / (
merged_df['true_positive_count'] + merged_df['false_positive_count'])

total_tp = merged_df['true_positive_count'].sum()
total_fn = merged_df['false_negative_count'].sum()
total_fp = merged_df['false_positive_count'].sum()
total_recall = total_tp / (total_tp + total_fn)
total_precision = total_tp / (total_tp + total_fp)

total_row = pd.DataFrame({
'entity_type': ['Total'],
'true_positive_count': [total_tp],
'false_negative_count': [total_fn],
'false_positive_count': [total_fp],
'recall': [total_recall],
'precision': [total_precision]
})
final_df = pd.concat([merged_df, total_row], ignore_index=True)
final_df = final_df[
['entity_type', 'true_positive_count', 'false_positive_count', 'false_negative_count', 'precision',
'recall']]
return final_df


def main(results: str = RESULTS_DIR):
results_path = os.path.expandvars(os.path.expanduser(results))
os.makedirs(results_path, exist_ok=True)

benchmarker = BioIDNERBenchmarker()
benchmarker.processed_data.to_csv(os.path.join(results_path, "processed_data.tsv"), sep='\t', index=False)
benchmarker.annotate_entities_with_gilda()
df = pd.DataFrame(list(benchmarker.gilda_annotations_map.items()), columns=['Key', 'Value'])
df.to_csv(os.path.join(results_path, "gilda_annotations_map.tsv"), sep='\t', index=False)
benchmarker.generate_result_table()
benchmarker.evaluate_gilda_performance()
counts, precision_recall, false_positives_counter = benchmarker.get_tables()

@@ -311,6 +452,7 @@ def main(results: str = RESULTS_DIR):

outname = f'benchmark_{time}'
result_stub = pathlib.Path(results_path).joinpath(outname)
entity_result = benchmarker.get_entity_result()

caption0 = dedent(f"""\
# Gilda NER Benchmarking
@@ -339,6 +481,16 @@ def main(results: str = RESULTS_DIR):
table2 = precision_recall.to_markdown(index=False)

caption3 = dedent("""\
## Table 3

Precision and recall values for Gilda performance broken down by entity
type. For this table, Gilda is considered correct if any of its
groundings matches the reference identifier or one of its synonyms.
""")
table_by_entity = entity_result.to_markdown(index=False)

caption4 = dedent("""\
## 50 Most Common False Positive Words

A list of 50 most common false positive annotations created by Gilda.
@@ -351,7 +503,8 @@ def main(results: str = RESULTS_DIR):
caption0,
caption1, table1,
caption2, table2,
caption3, false_positives_list
caption3, table_by_entity,
caption4, false_positives_list
])

md_path = result_stub.with_suffix(".md")
@@ -361,6 +514,10 @@ def main(results: str = RESULTS_DIR):
counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False)
precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"),
index=False)
benchmarker.result.to_csv(
result_stub.with_suffix(".ner_result.tsv"),
sep='\t', index=False)

print(f'Results saved to {results_path}')


2 changes: 1 addition & 1 deletion gilda/grounder.py
@@ -18,7 +18,7 @@
from .term import Term, get_identifiers_curie, get_identifiers_url
from .process import normalize, replace_dashes, replace_greek_uni, \
replace_greek_latin, replace_greek_spelled_out, depluralize, \
replace_roman_arabic
replace_roman_arabic, strip_greek_letters
from .scorer import Match, generate_match, score
from .resources import get_gilda_models, get_grounding_terms

24 changes: 22 additions & 2 deletions gilda/ner.py
@@ -45,7 +45,7 @@
the extension ``.txt`` and the annotations in a file with the
same name but extension ``.ann``.
"""

import re
from typing import List, Set
import os

@@ -54,7 +54,7 @@

from gilda import get_grounder
from gilda.grounder import Annotation
from gilda.process import normalize
from gilda.process import normalize, strip_greek_letters

__all__ = [
"annotate",
@@ -77,6 +77,10 @@ def _load_stoplist() -> Set[str]:
stop_words = set(stopwords.words('english'))
stop_words.update(_load_stoplist())

def preprocess_text(text):
# Replace Unicode hyphens, dashes, the soft hyphen, and curly single quotes with a space
text = re.sub(r'[\u2010\u2011\u2012\u2013\u2014\u2015\u2212\u00AD\u2018\u2019]', ' ', text)
return text

def annotate(
text, *,
Expand Down Expand Up @@ -117,6 +121,7 @@ def annotate(
the text span that was matched, the list of ScoredMatches, and the
start and end character offsets of the text span.
"""
text = preprocess_text(text)
if grounder is None:
grounder = get_grounder()
if sent_split_fun is None:
@@ -145,6 +150,8 @@
if word in stop_words:
continue
spans = grounder.prefix_index.get(word, set())
# Always consider a single-word span for words longer than one character,
# even if the word is not a known term prefix; build a new set so the
# shared prefix_index entry is not mutated in place.
if len(word) > 1:
spans = spans | {1}
if not spans:
continue

Expand All @@ -168,6 +175,19 @@ def annotate(
if len(raw_span) <= 1:
continue
context = text if context_text is None else context_text

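# If the span contains Greek letters, also try grounding it with the
# Greek letters stripped, and record any matches as an additional
# annotation over the same character span.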
if raw_span != strip_greek_letters(raw_span):
matches = grounder.ground(strip_greek_letters(raw_span),
context=context,
organisms=organisms,
namespaces=namespaces)
if matches:
start_coord = sent_start + raw_word_coords[idx][0]
end_coord = sent_start + raw_word_coords[idx + span - 1][1]
annotations.append(Annotation(
strip_greek_letters(raw_span), matches, start_coord, end_coord - 1
))

matches = grounder.ground(raw_span,
context=context,
organisms=organisms,
Expand Down
4 changes: 4 additions & 0 deletions gilda/process.py
@@ -126,6 +126,10 @@ def replace_greek_spelled_out(s):
s = s.replace(greek_uni, greek_spelled_out)
return s

def strip_greek_letters(s):
"""Strip Greek unicode character.
"""
return ''.join(c for c in s if c not in greek_alphabet)

def replace_unicode(s):
"""Replace unicode with ASCII equivalent, except Greek letters.
Empty file added scratch_2.py
Empty file.