From 5b258035cb9120ed250b76425c1eefcc68b0ee69 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 25 Oct 2023 12:34:43 +0900 Subject: [PATCH] add ner extraction on results --- client.py | 225 +++++++++++++++++++++++++++++++++ grobid_client_generic.py | 264 +++++++++++++++++++++++++++++++++++++++ grobid_processors.py | 15 +-- streamlit_app.py | 46 ++++++- 4 files changed, 539 insertions(+), 11 deletions(-) create mode 100644 client.py create mode 100644 grobid_client_generic.py diff --git a/client.py b/client.py new file mode 100644 index 0000000..42f7d19 --- /dev/null +++ b/client.py @@ -0,0 +1,225 @@ +""" Generic API Client """ +from copy import deepcopy +import json +import requests + +try: + from urlparse import urljoin +except ImportError: + from urllib.parse import urljoin + + +class ApiClient(object): + """ Client to interact with a generic Rest API. + + Subclasses should implement functionality accordingly with the provided + service methods, i.e. ``get``, ``post``, ``put`` and ``delete``. + """ + + accept_type = 'application/xml' + api_base = None + + def __init__( + self, + base_url, + username=None, + api_key=None, + status_endpoint=None, + timeout=60 + ): + """ Initialise client. + + Args: + base_url (str): The base URL to the service being used. + username (str): The username to authenticate with. + api_key (str): The API key to authenticate with. + timeout (int): Maximum time before timing out. + """ + self.base_url = base_url + self.username = username + self.api_key = api_key + self.status_endpoint = urljoin(self.base_url, status_endpoint) + self.timeout = timeout + + @staticmethod + def encode(request, data): + """ Add request content data to request body, set Content-type header. + + Should be overridden by subclasses if not using JSON encoding. + + Args: + request (HTTPRequest): The request object. + data (dict, None): Data to be encoded. + + Returns: + HTTPRequest: The request object. + """ + if data is None: + return request + + request.add_header('Content-Type', 'application/json') + request.extracted_data = json.dumps(data) + + return request + + @staticmethod + def decode(response): + """ Decode the returned data in the response. + + Should be overridden by subclasses if something else than JSON is + expected. + + Args: + response (HTTPResponse): The response object. + + Returns: + dict or None. + """ + try: + return response.json() + except ValueError as e: + return e.message + + def get_credentials(self): + """ Returns parameters to be added to authenticate the request. + + This lives on its own to make it easier to re-implement it if needed. + + Returns: + dict: A dictionary containing the credentials. + """ + return {"username": self.username, "api_key": self.api_key} + + def call_api( + self, + method, + url, + headers=None, + params=None, + data=None, + files=None, + timeout=None, + ): + """ Call API. + + This returns object containing data, with error details if applicable. + + Args: + method (str): The HTTP method to use. + url (str): Resource location relative to the base URL. + headers (dict or None): Extra request headers to set. + params (dict or None): Query-string parameters. + data (dict or None): Request body contents for POST or PUT requests. + files (dict or None: Files to be passed to the request. + timeout (int): Maximum time before timing out. + + Returns: + ResultParser or ErrorParser. + """ + headers = deepcopy(headers) or {} + headers['Accept'] = self.accept_type if 'Accept' not in headers else headers['Accept'] + params = deepcopy(params) or {} + data = data or {} + files = files or {} + #if self.username is not None and self.api_key is not None: + # params.update(self.get_credentials()) + r = requests.request( + method, + url, + headers=headers, + params=params, + files=files, + data=data, + timeout=timeout, + ) + + return r, r.status_code + + def get(self, url, params=None, **kwargs): + """ Call the API with a GET request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + + Returns: + ResultParser or ErrorParser. + """ + return self.call_api( + "GET", + url, + params=params, + **kwargs + ) + + def delete(self, url, params=None, **kwargs): + """ Call the API with a DELETE request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + + Returns: + ResultParser or ErrorParser. + """ + return self.call_api( + "DELETE", + url, + params=params, + **kwargs + ) + + def put(self, url, params=None, data=None, files=None, **kwargs): + """ Call the API with a PUT request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + data (dict or None): Request body contents. + files (dict or None: Files to be passed to the request. + + Returns: + An instance of ResultParser or ErrorParser. + """ + return self.call_api( + "PUT", + url, + params=params, + data=data, + files=files, + **kwargs + ) + + def post(self, url, params=None, data=None, files=None, **kwargs): + """ Call the API with a POST request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + data (dict or None): Request body contents. + files (dict or None: Files to be passed to the request. + + Returns: + An instance of ResultParser or ErrorParser. + """ + return self.call_api( + method="POST", + url=url, + params=params, + data=data, + files=files, + **kwargs + ) + + def service_status(self, **kwargs): + """ Call the API to get the status of the service. + + Returns: + An instance of ResultParser or ErrorParser. + """ + return self.call_api( + 'GET', + self.status_endpoint, + params={'format': 'json'}, + **kwargs + ) diff --git a/grobid_client_generic.py b/grobid_client_generic.py new file mode 100644 index 0000000..c06acea --- /dev/null +++ b/grobid_client_generic.py @@ -0,0 +1,264 @@ +import json +import os +import time + +import requests +import yaml + +from commons.client import ApiClient + +''' +This client is a generic client for any Grobid application and sub-modules. +At the moment, it supports only single document processing. + +Source: https://github.com/kermitt2/grobid-client-python +''' + + +class GrobidClientGeneric(ApiClient): + + def __init__(self, config_path=None, ping=False): + self.config = None + if config_path is not None: + self.config = self.load_yaml_config_from_file(path=config_path) + super().__init__(self.config['grobid']['server']) + + if ping: + result = self.ping_grobid() + if not result: + raise Exception("Grobid is down.") + + os.environ['NO_PROXY'] = "nims.go.jp" + + @staticmethod + def load_json_config_from_file(self, path='./config.json', ping=False): + """ + Load the json configuration + """ + config = {} + with open(path, 'r') as fp: + config = json.load(fp) + + if ping: + result = self.ping_grobid() + if not result: + raise Exception("Grobid is down.") + + return config + + def load_yaml_config_from_file(self, path='./config.yaml'): + """ + Load the YAML configuration + """ + config = {} + try: + with open(path, 'r') as the_file: + raw_configuration = the_file.read() + + config = yaml.safe_load(raw_configuration) + except Exception as e: + print("Configuration could not be loaded: ", str(e)) + exit(1) + + return config + + def set_config(self, config, ping=False): + self.config = config + if ping: + try: + result = self.ping_grobid() + if not result: + raise Exception("Grobid is down.") + except Exception as e: + raise Exception("Grobid is down or other problems were encountered. ", e) + + def ping_grobid(self): + # test if the server is up and running... + ping_url = self.get_grobid_url("ping") + + r = requests.get(ping_url) + status = r.status_code + + if status != 200: + print('GROBID server does not appear up and running ' + str(status)) + return False + else: + print("GROBID server is up and running") + return True + + def get_grobid_url(self, action): + grobid_config = self.config['grobid'] + base_url = grobid_config['server'] + action_url = base_url + grobid_config['url_mapping'][action] + + return action_url + + def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): + + files = { + 'texts': input + } + + the_url = self.get_grobid_url(method_name) + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=files, + data=params, + headers=headers + ) + + if status == 503: + time.sleep(self.config['sleep_time']) + return self.process_texts(input, method_name, params, headers) + elif status != 200: + print('Processing failed with error ' + str(status)) + return status, None + else: + return status, json.loads(res.text) + + def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): + + files = { + 'text': input + } + + the_url = self.get_grobid_url(method_name) + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=files, + data=params, + headers=headers + ) + + if status == 503: + time.sleep(self.config['sleep_time']) + return self.process_text(input, method_name, params, headers) + elif status != 200: + print('Processing failed with error ' + str(status)) + return status, None + else: + return status, json.loads(res.text) + + def process(self, form_data: dict, method_name='superconductors', params={}, headers={"Accept": "application/json"}): + + the_url = self.get_grobid_url(method_name) + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=form_data, + data=params, + headers=headers + ) + + if status == 503: + time.sleep(self.config['sleep_time']) + return self.process_text(input, method_name, params, headers) + elif status != 200: + print('Processing failed with error ' + str(status)) + else: + return res.text + + def process_pdf_batch(self, pdf_files, params={}): + pass + + def process_pdf(self, pdf_file, method_name, params={}, headers={"Accept": "application/json"}, verbose=False, + retry=None): + + files = { + 'input': ( + pdf_file, + open(pdf_file, 'rb'), + 'application/pdf', + {'Expires': '0'} + ) + } + + the_url = self.get_grobid_url(method_name) + + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=files, + data=params, + headers=headers + ) + + if status == 503 or status == 429: + if retry is None: + retry = self.config['max_retry'] - 1 + else: + if retry - 1 == 0: + if verbose: + print("re-try exhausted. Aborting request") + return None, status + else: + retry -= 1 + + sleep_time = self.config['sleep_time'] + if verbose: + print("Server is saturated, waiting", sleep_time, "seconds and trying again. ") + time.sleep(sleep_time) + return self.process_pdf(pdf_file, method_name, params, headers, verbose=verbose, retry=retry) + elif status != 200: + desc = None + if res.content: + c = json.loads(res.text) + desc = c['description'] if 'description' in c else None + return desc, status + elif status == 204: + # print('No content returned. Moving on. ') + return None, status + else: + return res.text, status + + def get_params_from_url(self, the_url): + params = {} + if "?" in the_url: + split = the_url.split("?") + the_url = split[0] + params = split[1] + + params = {param.split("=")[0]: param.split("=")[1] for param in params.split("&")} + return params, the_url + + def process_json(self, text, method_name="processJson", params={}, headers={"Accept": "application/json"}, + verbose=False): + files = { + 'input': ( + None, + text, + 'application/json', + {'Expires': '0'} + ) + } + + the_url = self.get_grobid_url(method_name) + + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=files, + data=params, + headers=headers + ) + + if status == 503: + time.sleep(self.config['sleep_time']) + return self.process_json(text, method_name, params, headers), status + elif status != 200: + if verbose: + print('Processing failed with error ', status) + return None, status + elif status == 204: + if verbose: + print('No content returned. Moving on. ') + return None, status + else: + return res.text, status diff --git a/grobid_processors.py b/grobid_processors.py index 85b34a8..d87cb25 100644 --- a/grobid_processors.py +++ b/grobid_processors.py @@ -412,7 +412,8 @@ def __init__(self, grobid_superconductors_client): self.grobid_superconductors_client = grobid_superconductors_client def extract_materials(self, text): - status, result = self.grobid_superconductors_client.process_text(text.strip(), "processText_disable_linking") + preprocessed_text = text.strip() + status, result = self.grobid_superconductors_client.process_text(preprocessed_text, "processText_disable_linking") if status != 200: result = {} @@ -420,10 +421,10 @@ def extract_materials(self, text): spans = [] if 'passages' in result: - materials = self.parse_superconductors_output(result, text) + materials = self.parse_superconductors_output(result, preprocessed_text) for m in materials: - item = {"text": text[m['offset_start']:m['offset_end']]} + item = {"text": preprocessed_text[m['offset_start']:m['offset_end']]} item['offset_start'] = m['offset_start'] item['offset_end'] = m['offset_end'] @@ -502,12 +503,12 @@ def parse_superconductors_output(result, original_text): class GrobidAggregationProcessor(GrobidProcessor, GrobidQuantitiesProcessor, GrobidMaterialsProcessor): def __init__(self, grobid_client, grobid_quantities_client=None, grobid_superconductors_client=None): GrobidProcessor.__init__(self, grobid_client) - GrobidQuantitiesProcessor.__init__(self, grobid_quantities_client) - GrobidMaterialsProcessor.__init__(self, grobid_superconductors_client) + self.gqp = GrobidQuantitiesProcessor(grobid_quantities_client) + self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client) def process_single_text(self, text): - extracted_quantities_spans = extract_quantities(self.grobid_quantities_client, text) - extracted_materials_spans = extract_materials(self.grobid_superconductors_client, text) + extracted_quantities_spans = self.gqp.extract_quantities(text) + extracted_materials_spans = self.gmp.extract_materials(text) all_entities = extracted_quantities_spans + extracted_materials_spans entities = self.prune_overlapping_annotations(all_entities) return entities diff --git a/streamlit_app.py b/streamlit_app.py index e8652a9..31f6bde 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -1,8 +1,10 @@ import os +import re from hashlib import blake2b from tempfile import NamedTemporaryFile import dotenv +from grobid_quantities.quantities import QuantitiesAPI from langchain.llms.huggingface_hub import HuggingFaceHub dotenv.load_dotenv(override=True) @@ -12,6 +14,8 @@ from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings from document_qa_engine import DocumentQAEngine +from grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations +from grobid_client_generic import GrobidClientGeneric if 'rqa' not in st.session_state: st.session_state['rqa'] = None @@ -38,7 +42,6 @@ if "messages" not in st.session_state: st.session_state.messages = [] - def new_file(): st.session_state['loaded_embeddings'] = None st.session_state['doc_id'] = None @@ -66,6 +69,33 @@ def init_qa(model): return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL']) +@st.cache_resource +def init_ner(): + quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True) + + materials_client = GrobidClientGeneric(ping=True) + config_materials = { + 'grobid': { + "server": os.environ['GROBID_MATERIALS_URL'], + 'sleep_time': 5, + 'timeout': 60, + 'url_mapping': { + 'processText_disable_linking': "/service/process/text?disableLinking=True", + # 'processText_disable_linking': "/service/process/text" + } + } + } + + materials_client.set_config(config_materials) + + gqa = GrobidAggregationProcessor(None, + grobid_quantities_client=quantities_client, + grobid_superconductors_client=materials_client + ) + + return gqa + +gqa = init_ner() def get_file_hash(fname): hash_md5 = blake2b() @@ -84,7 +114,7 @@ def play_old_messages(): elif message['role'] == 'assistant': with st.chat_message("assistant"): if mode == "LLM": - st.markdown(message['content']) + st.markdown(message['content'], unsafe_allow_html=True) else: st.write(message['content']) @@ -168,6 +198,7 @@ def play_old_messages(): chunk_size=250, perc_overlap=0.1) st.session_state['loaded_embeddings'] = True + st.session_state.messages = [] # timestamp = datetime.utcnow() @@ -175,7 +206,7 @@ def play_old_messages(): for message in st.session_state.messages: with st.chat_message(message["role"]): if message['mode'] == "LLM": - st.markdown(message["content"]) + st.markdown(message["content"], unsafe_allow_html=True) elif message['mode'] == "Embeddings": st.write(message["content"]) @@ -196,7 +227,14 @@ def play_old_messages(): with st.chat_message("assistant"): if mode == "LLM": - st.markdown(text_response) + entities = gqa.process_single_text(text_response) + # for entity in entities: + # entity + decorated_text = decorate_text_with_annotations(text_response.strip(), entities) + decorated_text = decorated_text.replace('class="label material"', 'style="color:blue"') + decorated_text = re.sub(r'class="label[^"]+"', 'style="color:yellow"', decorated_text) + st.markdown(decorated_text, unsafe_allow_html=True) + text_response = decorated_text else: st.write(text_response) st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})