diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py
new file mode 100644
index 0000000..7f24e9c
--- /dev/null
+++ b/benchmarks/bioid_ner_benchmark.py
@@ -0,0 +1,368 @@
+import os
+import json
+import pathlib
+import logging
+from datetime import datetime
+from collections import defaultdict, Counter
+import xml.etree.ElementTree as ET
+from textwrap import dedent
+from typing import List, Dict
+
+import pystow
+import pandas as pd
+from tqdm import tqdm
+
+import gilda
+from gilda.ner import annotate
+
+#from benchmarks.bioid_evaluation import fplx_members
+from benchmarks.bioid_evaluation import BioIDBenchmarker
+
+logging.getLogger('gilda.grounder').setLevel('WARNING')
+logger = logging.getLogger('bioid_ner_benchmark')
+
+# Constants
+HERE = os.path.dirname(os.path.abspath(__file__))
+
+DATA_DIR = os.path.join(HERE, 'data', 'BioIDtraining_2', 'caption_bioc')
+ANNOTATIONS_PATH = os.path.join(HERE, 'data', 'BioIDtraining_2',
+                                'annotations.csv')
+RESULTS_DIR = os.path.join(HERE, 'results', "bioid_ner_performance",
+                           gilda.__version__)
+MODULE = pystow.module('gilda', 'biocreative')
+URL = ('https://biocreative.bioinformatics.udel.edu/media/store/files/2017'
+       '/BioIDtraining_2.tar.gz')
+
+tqdm.pandas()
+
+BO_MISSING_XREFS = set()
+
+
+class BioIDNERBenchmarker(BioIDBenchmarker):
+    def __init__(self):
+        super().__init__()
+        print("Instantiating benchmarker...")
+        self.equivalences = self._load_equivalences()
+        self.paper_level_grounding = defaultdict(set)
+        self.processed_data = self.process_xml_files()  # processed XML data
+        self.annotations_df = self._process_annotations_table()  # CSV annotations
+        self.gilda_annotations_map = defaultdict(list)
+        self.annotations_count = 0
+        self.counts_table = None
+        self.precision_recall = None
+        self.false_positives_counter = Counter()
+
+    def process_xml_files(self):
+        """Extract relevant information from XML files."""
+        print("Extracting information from XML files...")
+        data = []
+        total_annotations = 0
+        for filename in os.listdir(DATA_DIR):
+            if filename.endswith('.xml'):
+                filepath = os.path.join(DATA_DIR, filename)
+                tree = ET.parse(filepath)
+                root = tree.getroot()
+                for document in root.findall('.//document'):
+                    doc_id_full = document.find('.//id').text.strip()
+                    # Split the full ID to get don_article and figure
+                    don_article, figure = doc_id_full.split(' ', 1)
+                    don_article = don_article
+                    for passage in document.findall('.//passage'):
+                        offset = int(passage.find('.//offset').text)
+                        text = passage.find('.//text').text
+                        annotations = []
+                        for annotation in passage.findall('.//annotation'):
+                            annot_id = annotation.get('id')
+                            annot_text = annotation.find('.//text').text
+                            annot_type = annotation.find(
+                                './/infon[@key="type"]').text
+                            annot_offset = int(
+                                annotation.find('.//location').attrib['offset'])
+                            annot_length = int(
+                                annotation.find('.//location').attrib['length'])
+                            annotations.append({
+                                'annot_id': annot_id,
+                                'annot_text': annot_text,
+                                'annot_type': annot_type,
+                                'annot_offset': annot_offset,
+                                'annot_length': annot_length,
+                            })
+                            total_annotations += 1
+                        data.append({
+                            'doc_id': don_article,
+                            'figure': figure,
+                            'offset': offset,
+                            'text': text,
+                            'annotations': annotations,
+                        })
+        print(f"Total annotations in XML files: {total_annotations}")
+        self.annotations_count = total_annotations
+        print("Finished extracting information from XML files.")
+        return pd.DataFrame(data)
+
+    def _load_equivalences(self) -> Dict[str, List[str]]:
+        with open(os.path.join(HERE, 'data', 'equivalences.json')) as f:
+            equivalences = json.load(f)
+        return equivalences
+
+    def _process_annotations_table(self):
+        """Extract relevant information from annotations table. Modified for
+        NER. Overrides the super method."""
+        print("Extracting information from annotations table...")
+        df = MODULE.ensure_tar_df(
+            url=URL,
+            inner_path='BioIDtraining_2/annotations.csv',
+            read_csv_kwargs=dict(sep=',', low_memory=False),
+        )
+        # Split entries with multiple groundings then normalize ids
+        df.loc[:, 'obj'] = df['obj'].apply(self._normalize_ids)
+        # Add synonyms of gold standard groundings to help match more things
+        df.loc[:, 'obj_synonyms'] = df['obj'].apply(self.get_synonym_set)
+        # Create column for entity type
+        df.loc[:, 'entity_type'] = df.apply(self._get_entity_type_helper,
+                                            axis=1)
+        processed_data = df[['text', 'obj', 'obj_synonyms', 'entity_type',
+                             'don_article', 'figure', 'annot id', 'first left',
+                             'last right']]
+        print("%d rows in processed annotations table." % len(processed_data))
+        processed_data = processed_data[processed_data.entity_type
+                                        != 'unknown']
+        print("%d rows in annotations table without unknowns." %
+              len(processed_data))
+        for don_article, text, synonyms in df[['don_article', 'text',
+                                               'obj_synonyms']].values:
+            self.paper_level_grounding[don_article, text].update(synonyms)
+        return processed_data
+
+    def annotate_entities_with_gilda(self):
+        """Performs NER on the XML files using gilda.annotate()"""
+        print("Annotating corpus with Gilda...")
+
+        total_gilda_annotations = 0
+        for _, item in tqdm(self.processed_data.iterrows(),
+                            total=self.processed_data.shape[0],
+                            desc="Annotating with Gilda"):
+            doc_id = item['doc_id']
+            figure = item['figure']
+            text = item['text']
+
+            # Get the full text for the paper-level disambiguation
+            full_text = self._get_plaintext(doc_id)
+
+            gilda_annotations = annotate(text, context_text=full_text)
+
+            for annotation in gilda_annotations:
+                total_gilda_annotations += 1
+
+                self.gilda_annotations_map[(doc_id, figure)].append(annotation)
+
+        tqdm.write("Finished annotating corpus with Gilda...")
+        tqdm.write(f"Total Gilda annotations: {total_gilda_annotations}")
+
+    def evaluate_gilda_performance(self):
+        """Calculates precision, recall, and F1"""
+        print("Evaluating performance...")
+
+        metrics = {
+            'all_matches': {'tp': 0, 'fp': 0, 'fn': 0},
+            'top_match': {'tp': 0, 'fp': 0, 'fn': 0}
+        }
+
+
+
+        ref_dict = defaultdict(list)
+        for _, row in self.annotations_df.iterrows():
+            key = (str(row['don_article']), row['figure'], row['text'],
+                   row['first left'], row['last right'])
+            ref_dict[key].append((set(row['obj']), row['obj_synonyms']))
+
+        for (doc_id, figure), annotations in (
+                tqdm(self.gilda_annotations_map.items(),
+                     desc="Evaluating Annotations")):
+            for annotation in annotations:
+                key = (doc_id, figure, annotation.text, annotation.start,
+                       annotation.end)
+                matching_refs = ref_dict.get(key, [])
+
+                match_found = False
+                for i, scored_match in enumerate(annotation.matches):
+                    curie = f"{scored_match.term.db}:{scored_match.term.id}"
+
+                    for original_curies, synonyms in matching_refs:
+                        if curie in original_curies or curie in synonyms:
+                            metrics['all_matches']['tp'] += 1
+                            if i == 0:  # Top match
+                                metrics['top_match']['tp'] += 1
+                            match_found = True
+                            break
+
+                    if match_found:
+                        break
+
+                if not match_found:
+                    metrics['all_matches']['fp'] += 1
+                    self.false_positives_counter[annotation.text] += 1
+                    if annotation.matches:  # Check if there are any matches
+                        metrics['top_match']['fp'] += 1
+
+        # False negative calculation using ref dict
+        for key, refs in tqdm(ref_dict.items(),
+                              desc="Calculating False Negatives"):
+            doc_id, figure = key[0], key[1]
+            gilda_annotations = self.gilda_annotations_map.get((doc_id, figure),
+                                                               [])
+            for original_curies, synonyms in refs:
+                match_found = any(
+                    ann.text == key[2] and
+                    ann.start == key[3] and
+                    ann.end == key[4] and
+                    any(f"{match.term.db}:{match.term.id}" in original_curies or
+                        f"{match.term.db}:{match.term.id}" in synonyms
+                        for match in ann.matches)
+                    for ann in gilda_annotations
+                )
+
+                if not match_found:
+                    metrics['all_matches']['fn'] += 1
+                    metrics['top_match']['fn'] += 1
+
+        results = {}
+        for match_type, counts in metrics.items():
+            precision = counts['tp'] / (counts['tp'] + counts['fp']) \
+                if ((counts['tp'] + counts['fp']) > 0) else 0
+
+            recall = counts['tp'] / (counts['tp'] + counts['fn']) \
+                if (counts['tp'] + counts['fn']) > 0 else 0
+
+            f1 = 2 * (precision * recall) / (precision + recall) \
+                if (precision + recall) > 0 else 0
+
+            results[match_type] = {
+                'precision': precision,
+                'recall': recall,
+                'f1': f1
+            }
+
+        counts_table = pd.DataFrame([
+            {
+                'Match Type': 'All Matches',
+                'True Positives': metrics['all_matches']['tp'],
+                'False Positives': metrics['all_matches']['fp'],
+                'False Negatives': metrics['all_matches']['fn']
+            },
+            {
+                'Match Type': 'Top Match',
+                'True Positives': metrics['top_match']['tp'],
+                'False Positives': metrics['top_match']['fp'],
+                'False Negatives': metrics['top_match']['fn']
+            }
+        ])
+
+        precision_recall = pd.DataFrame([
+            {
+                'Match Type': 'All Matches',
+                'Precision': results['all_matches']['precision'],
+                'Recall': results['all_matches']['recall'],
+                'F1 Score': results['all_matches']['f1']
+            },
+            {
+                'Match Type': 'Top Match',
+                'Precision': results['top_match']['precision'],
+                'Recall': results['top_match']['recall'],
+                'F1 Score': results['top_match']['f1']
+            }
+        ])
+
+        self.counts_table = counts_table
+        self.precision_recall = precision_recall
+
+        os.makedirs(RESULTS_DIR, exist_ok=True)
+        false_positives_df = pd.DataFrame(self.false_positives_counter.items(),
+                                          columns=['False Positive Text',
+                                                   'Count'])
+        false_positives_df = false_positives_df.sort_values(by='Count',
+                                                            ascending=False)
+        false_positives_df.to_csv(
+            os.path.join(RESULTS_DIR, 'false_positives.csv'), index=False)
+
+        print("Finished evaluating performance...")
+
+    def get_tables(self):
+        return (self.counts_table,
+                self.precision_recall,
+                self.false_positives_counter)
+
+
+def main(results: str = RESULTS_DIR):
+    results_path = os.path.expandvars(os.path.expanduser(results))
+    os.makedirs(results_path, exist_ok=True)
+
+    benchmarker = BioIDNERBenchmarker()
+    benchmarker.annotate_entities_with_gilda()
+    benchmarker.evaluate_gilda_performance()
+    counts, precision_recall, false_positives_counter = benchmarker.get_tables()
+
+    print("Counts Table:")
+    print(counts.to_markdown(index=False))
+    print("Precision and Recall Table:")
+    print(precision_recall.to_markdown(index=False))
+
+    time = datetime.now().strftime('%y%m%d-%H%M%S')
+
+    outname = f'benchmark_{time}'
+    result_stub = pathlib.Path(results_path).joinpath(outname)
+
+    caption0 = dedent(f"""\
+    # Gilda NER Benchmarking
+
+    Gilda: v{gilda.__version__}
+    Date: {time}
+    """)
+
+    caption1 = dedent("""\
+    ## Table 1
+
+    The counts of true positives, false positives, and false negatives for
+    Gilda annotations in the corpus, both when only Gilda's top-scoring
+    ("Top Match") grounding counts as a correct match and when any returned
+    Gilda grounding counts as a correct match.
+ """) + table1 = counts.to_markdown(index=False) + + caption2 = dedent("""\ + ## Table 2 + + Precision, recall, and F1 Score values for Gilda performance where + Gilda's "Top Match" grounding (top score grounding) returns the + correct match and where any Gilda grounding returns a correct match. + """) + table2 = precision_recall.to_markdown(index=False) + + caption3 = dedent("""\ + ## 50 Most Common False Positive Words + + A list of 50 most common false positive annotations created by Gilda. + """) + top_50_false_positives = false_positives_counter.most_common(50) + false_positives_list = '\n'.join( + [f'- {word}: {count}' for word, count in top_50_false_positives]) + + output = '\n\n'.join([ + caption0, + caption1, table1, + caption2, table2, + caption3, false_positives_list + ]) + + md_path = result_stub.with_suffix(".md") + with open(md_path, 'w') as f: + f.write(output) + + counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False) + precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"), + index=False) + print(f'Results saved to {results_path}') + + +if __name__ == '__main__': + main() diff --git a/gilda/app/app.py b/gilda/app/app.py index 801f96a..a0caf74 100644 --- a/gilda/app/app.py +++ b/gilda/app/app.py @@ -8,6 +8,7 @@ from gilda import __version__ as version from gilda.grounder import GrounderInput, Grounder from gilda.app.proxies import grounder +from gilda.ner import annotate # NOTE: the Flask REST-X API has to be declared here, below the home endpoint # otherwise it reserves the / base path. @@ -46,15 +47,15 @@ term_model = api.model( "Term", - {'norm_text' : fields.String( + {'norm_text': fields.String( description='The normalized text corresponding to the text entry, ' 'used for lookups.', example='egf receptor'), - 'text' : fields.String( + 'text': fields.String( description='The text entry that was matched.', example='EGF receptor' ), - 'db' : fields.String( + 'db': fields.String( description='The database / namespace corresponding to the ' 'grounded term.', example='HGNC' @@ -97,8 +98,7 @@ scored_match_model = api.model( "ScoredMatch", - {'term': fields.Nested(term_model, - description='The term that was matched'), + {'term': fields.Nested(term_model, description='The term that was matched'), 'url': fields.String( description='Identifiers.org URL for the matched term.', example='https://identifiers.org/hgnc:3236' @@ -120,14 +120,13 @@ } ) - get_names_input_model = api.model( "GetNamesInput", {'db': fields.String( - description="Capitalized name of the database for the grounding, " - "e.g. HGNC.", - required=True, - example='HGNC'), + description="Capitalized name of the database for the grounding, " + "e.g. 
HGNC.", + required=True, + example='HGNC'), 'id': fields.String( description="Identifier within the given database", required=True, @@ -147,7 +146,7 @@ "different sources.", required=False, example='uniprot' - ) + ) } ) @@ -161,8 +160,8 @@ ner_input_model = api.model('NERInput', { 'text': fields.String(required=True, description='Text on which to perform' ' NER', - example='The EGF receptor binds EGF which is an interaction' - 'important in cancer.'), + example='The EGF receptor binds EGF which is an ' + 'interaction important in cancer.'), 'organisms': fields.List(fields.String, example=['9606'], description='An optional list of taxonomy ' 'species IDs defining a priority list' @@ -185,8 +184,8 @@ }) names_model = fields.List( - fields.String, - example=['EGF receptor', 'EGFR', 'ERBB1', 'Proto-oncogene c-ErbB-1']) + fields.String, + example=['EGF receptor', 'EGFR', 'ERBB1', 'Proto-oncogene c-ErbB-1']) models_model = fields.List( fields.String, @@ -212,7 +211,8 @@ def post(self): text = request.json.get('text') context = request.json.get('context') organisms = request.json.get('organisms') - scored_matches = grounder.ground(text, context=context, organisms=organisms) + scored_matches = grounder.ground(text, context=context, + organisms=organisms) res = [sm.to_json() for sm in scored_matches] return jsonify(res) @@ -238,7 +238,8 @@ def post(self): text = input.get('text') context = input.get('context') organisms = input.get('organisms') - scored_matches = grounder.ground(text, context=context, organisms=organisms) + scored_matches = grounder.ground(text, context=context, + organisms=organisms) all_matches.append([sm.to_json() for sm in scored_matches]) return jsonify(all_matches) @@ -311,7 +312,6 @@ def post(self): return jsonify([annotation.to_json() for annotation in results]) - def get_app(terms: Optional[GrounderInput] = None, *, ui: bool = True) -> Flask: app = Flask(__name__) app.config['RESTX_MASK_SWAGGER'] = False diff --git a/gilda/ner.py b/gilda/ner.py index 39a5009..49dae9e 100644 --- a/gilda/ner.py +++ b/gilda/ner.py @@ -14,6 +14,7 @@ - the `start` position in the text string where the entity starts - the `end` position in the text string where the entity ends + In this example, the two concepts are grounded to FamPlex entries. >>> results[0].text, results[0].matches[0].term.get_curie(), results[0].start, results[0].end @@ -45,7 +46,8 @@ same name but extension ``.ann``. 
""" -from typing import List +from typing import List, Set +import os from nltk.corpus import stopwords from nltk.tokenize import PunktSentenceTokenizer, TreebankWordTokenizer @@ -60,7 +62,20 @@ "stop_words" ] +STOPLIST_PATH = os.path.join(os.path.dirname(__file__),'resources', + 'ner_stoplist.txt') + + +def _load_stoplist() -> Set[str]: + """Load NER stoplist from file.""" + stoplist_path = STOPLIST_PATH + with open(stoplist_path, 'r') as file: + stoplist = {line.strip() for line in file} + return stoplist + + stop_words = set(stopwords.words('english')) +stop_words.update(_load_stoplist()) def annotate( @@ -149,6 +164,9 @@ def annotate( spaces = ' ' * (c[0] - len(raw_span) - raw_word_coords[idx][0]) raw_span += spaces + rw + # If span is a single character, we don't want to consider it + if len(raw_span) <= 1: + continue context = text if context_text is None else context_text matches = grounder.ground(raw_span, context=context, diff --git a/gilda/resources/ner_stoplist.txt b/gilda/resources/ner_stoplist.txt new file mode 100644 index 0000000..3d5d26a --- /dev/null +++ b/gilda/resources/ner_stoplist.txt @@ -0,0 +1,171 @@ +-I +-II +-III +14 +A-C +ANOVA +Bar +Bark +Bars +Cell +Cells +Control +Ctrl +DNA +Fig +KDKO +Left +Methods +NS +RNA +Right +Rod +SD +SDS-PAGE +SEM +Scott +Student +Table +Task +XREF_BIBR +XREF_FIG +[ +] +acid +age +alpha +andD +animals +ankle +ankles +antibodies +antibody +antigen +area +arrowheads +bar +bark +bars +basal +bean +beta +bi +binding +biological replicates +bite +blot +cell +cells +clones +condition +control +crash +cryptic +culture +cultures +damage +danger +docking +duet +duration +et +experiment +face +fact +fast +fate +feet +fig +figure +finger +fingers +fist +fluorescence +foot +form +gain +gene +genes +genotype +goat +group +growth +hand +hands +head +hip +hips +hr +image +immunoblotting +impact +individual +inhibitor +injury +intensity +ir +knee +knees +lead +left +leg +legs +light +link +links +localization +mM +mark +matrix +media +membrane +microscopy +mitochondrial +nM +neck +net +neurons +nm +one +partial +patients +per +phosphorylation +plasmid +plasmids +post +prey +probe +processes +protein +protein levels +proteins +red +result +right +rod +role +sensor +set +shoulder +shoulders +size +spatial +starvation +strain +task +time +tissue +toe +toes +top +transfection +treatment +tube +type +vs +water +white +wt +µM +µm diff --git a/gilda/tests/test_ner.py b/gilda/tests/test_ner.py index 6c54758..67e0eb1 100644 --- a/gilda/tests/test_ner.py +++ b/gilda/tests/test_ner.py @@ -12,22 +12,19 @@ def test_annotate(): assert isinstance(annotations, list) # Check that we get 7 annotations - assert len(annotations) == 7 + assert len(annotations) == 4 # Check that the annotations are for the expected words assert tuple(a.text for a in annotations) == ( - 'protein', 'BRAF', 'kinase', 'BRAF', 'gene', 'BRAF', 'protein') + 'BRAF', 'kinase', 'BRAF', 'BRAF') # Check that the spans are correct - expected_spans = ((4, 11), (12, 16), (22, 28), (30, 34), (40, 44), - (46, 50), (56, 63)) + expected_spans = ((12, 16), (22, 28), (30, 34), (46, 50)) actual_spans = tuple((a.start, a.end) for a in annotations) assert actual_spans == expected_spans # Check that the curies are correct - expected_curies = ("CHEBI:36080", "hgnc:1097", "mesh:D010770", - "hgnc:1097", "mesh:D005796", "hgnc:1097", - "CHEBI:36080") + expected_curies = ("hgnc:1097", "mesh:D010770", "hgnc:1097", "hgnc:1097") actual_curies = tuple(a.matches[0].term.get_curie() for a in annotations) assert actual_curies 
== expected_curies @@ -40,20 +37,14 @@ def test_get_brat(): assert isinstance(brat_str, str) match_str = dedent(""" - T1\tEntity 4 11\tprotein - #1\tAnnotatorNotes T1\tCHEBI:36080 - T2\tEntity 12 16\tBRAF - #2\tAnnotatorNotes T2\thgnc:1097 - T3\tEntity 22 28\tkinase - #3\tAnnotatorNotes T3\tmesh:D010770 - T4\tEntity 30 34\tBRAF + T1\tEntity 12 16\tBRAF + #1\tAnnotatorNotes T1\thgnc:1097 + T2\tEntity 22 28\tkinase + #2\tAnnotatorNotes T2\tmesh:D010770 + T3\tEntity 30 34\tBRAF + #3\tAnnotatorNotes T3\thgnc:1097 + T4\tEntity 46 50\tBRAF #4\tAnnotatorNotes T4\thgnc:1097 - T5\tEntity 40 44\tgene - #5\tAnnotatorNotes T5\tmesh:D005796 - T6\tEntity 46 50\tBRAF - #6\tAnnotatorNotes T6\thgnc:1097 - T7\tEntity 56 63\tprotein - #7\tAnnotatorNotes T7\tCHEBI:36080 """).lstrip() assert brat_str == match_str
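Usage sketch (supplementary; not part of the diff above): a minimal way to
exercise the stoplist-aware NER behavior that the updated tests check. It
assumes gilda is installed with its NLTK data available; the example sentence
is reconstructed from the spans asserted in gilda/tests/test_ner.py.

    from gilda.ner import annotate

    text = "The protein BRAF is a kinase. BRAF is a gene. BRAF is a protein."
    for ann in annotate(text):
        # Generic words such as "protein" and "gene" are filtered out by the
        # new ner_stoplist.txt, so only the BRAF and kinase spans remain.
        curie = ann.matches[0].term.get_curie() if ann.matches else None
        print(ann.text, ann.start, ann.end, curie)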