diff --git a/README.md b/README.md
index 3635257..2dd8458 100644
--- a/README.md
+++ b/README.md
@@ -12,13 +12,15 @@ license: apache-2.0
# DocumentIQA: Scientific Document Insight QA
+**Work in progress** :construction_worker:
+
## Introduction
Question/Answering on scientific documents using LLMs (OpenAI, Mistral, ~~LLama2,~~ etc.).
This application is the frontend for testing RAG (Retrieval Augmented Generation) on scientific documents, which we are developing at NIMS.
-Differently to most of the project, we focus on scientific articles and we are using [Grobid](https://github.com/kermitt2/grobid) for text extraction instead of the raw PDF2Text converter (which is comparable with most of other solutions) allow to extract only full-text.
+Unlike most similar projects, we focus on scientific articles. We target only the full text, extracted with [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than the raw PDF2Text converters that most other solutions rely on.
-**Work in progress**
+**NER in LLM response**: The responses from the LLMs are post-processed to extract physical quantities, measurements (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and materials mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
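+
+A minimal sketch of this post-processing step, as wired up in `streamlit_app.py` (the service URLs below are placeholders, not shipped defaults):
+
+```python
+from grobid_quantities.quantities import QuantitiesAPI
+
+from grobid_client_generic import GrobidClientGeneric
+from grobid_processors import GrobidAggregationProcessor
+
+quantities_client = QuantitiesAPI("http://localhost:8060", check_server=True)
+
+materials_client = GrobidClientGeneric()
+materials_client.set_config({'grobid': {
+    'server': "http://localhost:8072",
+    'sleep_time': 5,
+    'url_mapping': {'processText_disable_linking': "/service/process/text?disableLinking=True"}
+}})
+
+processor = GrobidAggregationProcessor(None,
+                                       grobid_quantities_client=quantities_client,
+                                       grobid_superconductors_client=materials_client)
+entities = processor.process_single_text("MgB2 shows superconductivity below 39 K.")
+```
+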
**Demos**:
- (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
@@ -31,7 +33,7 @@ Differently to most of the project, we focus on scientific articles and we are u
- Upload a scientific article as a PDF document. You will see a spinner or loading indicator while the processing is in progress.
- Once the spinner stops, you can proceed to ask your questions
- ![screenshot1.png](docs%2Fimages%2Fscreenshot1.png)
+ ![screenshot2.png](docs%2Fimages%2Fscreenshot2.png)
### Options
#### Context size
diff --git a/client.py b/client.py
new file mode 100644
index 0000000..42f7d19
--- /dev/null
+++ b/client.py
@@ -0,0 +1,225 @@
+""" Generic API Client """
+from copy import deepcopy
+import json
+import requests
+
+try:
+ from urlparse import urljoin
+except ImportError:
+ from urllib.parse import urljoin
+
+
+class ApiClient(object):
+ """ Client to interact with a generic Rest API.
+
+    Subclasses should implement functionality in accordance with the provided
+    service methods, i.e. ``get``, ``post``, ``put`` and ``delete``.
+ """
+
+ accept_type = 'application/xml'
+ api_base = None
+
+ def __init__(
+ self,
+ base_url,
+ username=None,
+ api_key=None,
+ status_endpoint=None,
+ timeout=60
+ ):
+ """ Initialise client.
+
+ Args:
+ base_url (str): The base URL to the service being used.
+ username (str): The username to authenticate with.
+            api_key (str): The API key to authenticate with.
+            status_endpoint (str): Relative path to the service status endpoint.
+            timeout (int): Maximum time before timing out.
+ """
+ self.base_url = base_url
+ self.username = username
+ self.api_key = api_key
+        self.status_endpoint = urljoin(self.base_url, status_endpoint) if status_endpoint else None
+ self.timeout = timeout
+
+ @staticmethod
+ def encode(request, data):
+ """ Add request content data to request body, set Content-type header.
+
+ Should be overridden by subclasses if not using JSON encoding.
+
+ Args:
+ request (HTTPRequest): The request object.
+ data (dict, None): Data to be encoded.
+
+ Returns:
+ HTTPRequest: The request object.
+ """
+ if data is None:
+ return request
+
+ request.add_header('Content-Type', 'application/json')
+        request.data = json.dumps(data)
+
+ return request
+
+ @staticmethod
+ def decode(response):
+ """ Decode the returned data in the response.
+
+        Should be overridden by subclasses if something other than JSON is
+        expected.
+
+ Args:
+ response (HTTPResponse): The response object.
+
+ Returns:
+ dict or None.
+ """
+ try:
+ return response.json()
+        except ValueError as e:
+            # Python 3 exceptions have no .message attribute
+            return str(e)
+
+ def get_credentials(self):
+ """ Returns parameters to be added to authenticate the request.
+
+ This lives on its own to make it easier to re-implement it if needed.
+
+ Returns:
+ dict: A dictionary containing the credentials.
+ """
+ return {"username": self.username, "api_key": self.api_key}
+
+ def call_api(
+ self,
+ method,
+ url,
+ headers=None,
+ params=None,
+ data=None,
+ files=None,
+ timeout=None,
+ ):
+ """ Call API.
+
+        This returns an object containing the data, with error details if applicable.
+
+ Args:
+ method (str): The HTTP method to use.
+ url (str): Resource location relative to the base URL.
+ headers (dict or None): Extra request headers to set.
+ params (dict or None): Query-string parameters.
+ data (dict or None): Request body contents for POST or PUT requests.
+            files (dict or None): Files to be passed to the request.
+ timeout (int): Maximum time before timing out.
+
+ Returns:
+ ResultParser or ErrorParser.
+ """
+ headers = deepcopy(headers) or {}
+        headers.setdefault('Accept', self.accept_type)
+ params = deepcopy(params) or {}
+ data = data or {}
+ files = files or {}
+ #if self.username is not None and self.api_key is not None:
+ # params.update(self.get_credentials())
+ r = requests.request(
+ method,
+ url,
+ headers=headers,
+ params=params,
+ files=files,
+ data=data,
+ timeout=timeout,
+ )
+
+ return r, r.status_code
+
+ def get(self, url, params=None, **kwargs):
+ """ Call the API with a GET request.
+
+ Args:
+ url (str): Resource location relative to the base URL.
+ params (dict or None): Query-string parameters.
+
+ Returns:
+ ResultParser or ErrorParser.
+ """
+ return self.call_api(
+ "GET",
+ url,
+ params=params,
+ **kwargs
+ )
+
+ def delete(self, url, params=None, **kwargs):
+ """ Call the API with a DELETE request.
+
+ Args:
+ url (str): Resource location relative to the base URL.
+ params (dict or None): Query-string parameters.
+
+ Returns:
+ ResultParser or ErrorParser.
+ """
+ return self.call_api(
+ "DELETE",
+ url,
+ params=params,
+ **kwargs
+ )
+
+ def put(self, url, params=None, data=None, files=None, **kwargs):
+ """ Call the API with a PUT request.
+
+ Args:
+ url (str): Resource location relative to the base URL.
+ params (dict or None): Query-string parameters.
+ data (dict or None): Request body contents.
+            files (dict or None): Files to be passed to the request.
+
+ Returns:
+ An instance of ResultParser or ErrorParser.
+ """
+ return self.call_api(
+ "PUT",
+ url,
+ params=params,
+ data=data,
+ files=files,
+ **kwargs
+ )
+
+ def post(self, url, params=None, data=None, files=None, **kwargs):
+ """ Call the API with a POST request.
+
+ Args:
+ url (str): Resource location relative to the base URL.
+ params (dict or None): Query-string parameters.
+ data (dict or None): Request body contents.
+            files (dict or None): Files to be passed to the request.
+
+ Returns:
+ An instance of ResultParser or ErrorParser.
+ """
+ return self.call_api(
+ method="POST",
+ url=url,
+ params=params,
+ data=data,
+ files=files,
+ **kwargs
+ )
+
+ def service_status(self, **kwargs):
+ """ Call the API to get the status of the service.
+
+ Returns:
+ An instance of ResultParser or ErrorParser.
+ """
+ return self.call_api(
+ 'GET',
+ self.status_endpoint,
+ params={'format': 'json'},
+ **kwargs
+ )
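+
+
+# A minimal usage sketch (the base URL and status endpoint below are
+# placeholders, not values shipped with this module):
+#
+#     client = ApiClient(base_url="http://localhost:8070/",
+#                        status_endpoint="api/isalive")
+#     response, status = client.service_status()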
diff --git a/docs/images/screenshot2.png b/docs/images/screenshot2.png
new file mode 100644
index 0000000..843c8c8
Binary files /dev/null and b/docs/images/screenshot2.png differ
diff --git a/grobid_client_generic.py b/grobid_client_generic.py
new file mode 100644
index 0000000..c06acea
--- /dev/null
+++ b/grobid_client_generic.py
@@ -0,0 +1,264 @@
+import json
+import os
+import time
+
+import requests
+import yaml
+
+from commons.client import ApiClient
+
+'''
+This client is a generic client for any Grobid application and sub-modules.
+At the moment, it supports only single document processing.
+
+Source: https://github.com/kermitt2/grobid-client-python
+'''
+
+
+class GrobidClientGeneric(ApiClient):
+
+    def __init__(self, config_path=None, ping=False):
+        self.config = None
+        if config_path is not None:
+            self.config = self.load_yaml_config_from_file(path=config_path)
+            # Initialise the base client (and optionally ping) only once a
+            # configuration is available; callers without a config_path are
+            # expected to call set_config() afterwards.
+            super().__init__(self.config['grobid']['server'])
+
+            if ping:
+                result = self.ping_grobid()
+                if not result:
+                    raise Exception("Grobid is down.")
+
+        os.environ['NO_PROXY'] = "nims.go.jp"
+
+    def load_json_config_from_file(self, path='./config.json', ping=False):
+        """
+        Load the JSON configuration
+        """
+        # not a @staticmethod: ping_grobid() needs the instance
+        with open(path, 'r') as fp:
+            config = json.load(fp)
+
+        if ping:
+            result = self.ping_grobid()
+            if not result:
+                raise Exception("Grobid is down.")
+
+        return config
+
+ def load_yaml_config_from_file(self, path='./config.yaml'):
+ """
+ Load the YAML configuration
+ """
+ config = {}
+ try:
+ with open(path, 'r') as the_file:
+ raw_configuration = the_file.read()
+
+ config = yaml.safe_load(raw_configuration)
+ except Exception as e:
+ print("Configuration could not be loaded: ", str(e))
+ exit(1)
+
+ return config
+
+ def set_config(self, config, ping=False):
+ self.config = config
+ if ping:
+ try:
+ result = self.ping_grobid()
+ if not result:
+ raise Exception("Grobid is down.")
+            except Exception as e:
+                raise Exception("Grobid is down or other problems were encountered: " + str(e))
+
+ def ping_grobid(self):
+ # test if the server is up and running...
+ ping_url = self.get_grobid_url("ping")
+
+ r = requests.get(ping_url)
+ status = r.status_code
+
+ if status != 200:
+            print('GROBID server does not appear to be up and running, status: ' + str(status))
+ return False
+ else:
+ print("GROBID server is up and running")
+ return True
+
+ def get_grobid_url(self, action):
+ grobid_config = self.config['grobid']
+ base_url = grobid_config['server']
+ action_url = base_url + grobid_config['url_mapping'][action]
+
+ return action_url
+
+ def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
+
+ files = {
+ 'texts': input
+ }
+
+ the_url = self.get_grobid_url(method_name)
+ params, the_url = self.get_params_from_url(the_url)
+
+ res, status = self.post(
+ url=the_url,
+ files=files,
+ data=params,
+ headers=headers
+ )
+
+ if status == 503:
+ time.sleep(self.config['sleep_time'])
+ return self.process_texts(input, method_name, params, headers)
+ elif status != 200:
+ print('Processing failed with error ' + str(status))
+ return status, None
+ else:
+ return status, json.loads(res.text)
+
+ def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
+
+ files = {
+ 'text': input
+ }
+
+ the_url = self.get_grobid_url(method_name)
+ params, the_url = self.get_params_from_url(the_url)
+
+ res, status = self.post(
+ url=the_url,
+ files=files,
+ data=params,
+ headers=headers
+ )
+
+ if status == 503:
+ time.sleep(self.config['sleep_time'])
+ return self.process_text(input, method_name, params, headers)
+ elif status != 200:
+ print('Processing failed with error ' + str(status))
+ return status, None
+ else:
+ return status, json.loads(res.text)
+
+ def process(self, form_data: dict, method_name='superconductors', params={}, headers={"Accept": "application/json"}):
+
+ the_url = self.get_grobid_url(method_name)
+ params, the_url = self.get_params_from_url(the_url)
+
+ res, status = self.post(
+ url=the_url,
+ files=form_data,
+ data=params,
+ headers=headers
+ )
+
+        if status == 503:
+            time.sleep(self.config['sleep_time'])
+            # Retry the same request; the previous version erroneously
+            # recursed into process_text() with the input builtin
+            return self.process(form_data, method_name, params, headers)
+        elif status != 200:
+            print('Processing failed with error ' + str(status))
+            return None
+        else:
+            return res.text
+
+ def process_pdf_batch(self, pdf_files, params={}):
+ pass
+
+ def process_pdf(self, pdf_file, method_name, params={}, headers={"Accept": "application/json"}, verbose=False,
+ retry=None):
+
+ files = {
+ 'input': (
+ pdf_file,
+ open(pdf_file, 'rb'),
+ 'application/pdf',
+ {'Expires': '0'}
+ )
+ }
+
+ the_url = self.get_grobid_url(method_name)
+
+ params, the_url = self.get_params_from_url(the_url)
+
+ res, status = self.post(
+ url=the_url,
+ files=files,
+ data=params,
+ headers=headers
+ )
+
+ if status == 503 or status == 429:
+ if retry is None:
+ retry = self.config['max_retry'] - 1
+ else:
+ if retry - 1 == 0:
+ if verbose:
+ print("re-try exhausted. Aborting request")
+ return None, status
+ else:
+ retry -= 1
+
+ sleep_time = self.config['sleep_time']
+ if verbose:
+ print("Server is saturated, waiting", sleep_time, "seconds and trying again. ")
+ time.sleep(sleep_time)
+ return self.process_pdf(pdf_file, method_name, params, headers, verbose=verbose, retry=retry)
+        elif status == 204:
+            # No content returned; nothing to parse. This check must precede
+            # the generic != 200 branch, which would otherwise swallow it
+            return None, status
+        elif status != 200:
+            desc = None
+            if res.content:
+                c = json.loads(res.text)
+                desc = c.get('description')
+            return desc, status
+        else:
+            return res.text, status
+
+    def get_params_from_url(self, the_url):
+        """
+        Split any query string off a URL.
+
+        e.g. ".../process/text?disableLinking=True" returns
+        ({'disableLinking': 'True'}, '.../process/text')
+        """
+        params = {}
+        if "?" in the_url:
+            the_url, query_string = the_url.split("?", 1)
+            # Build the dict only when a query string is present; the old
+            # version called .split("&") on the dict default and crashed on
+            # URLs without parameters
+            params = {param.split("=")[0]: param.split("=")[1] for param in query_string.split("&")}
+        return params, the_url
+
+ def process_json(self, text, method_name="processJson", params={}, headers={"Accept": "application/json"},
+ verbose=False):
+ files = {
+ 'input': (
+ None,
+ text,
+ 'application/json',
+ {'Expires': '0'}
+ )
+ }
+
+ the_url = self.get_grobid_url(method_name)
+
+ params, the_url = self.get_params_from_url(the_url)
+
+ res, status = self.post(
+ url=the_url,
+ files=files,
+ data=params,
+ headers=headers
+ )
+
+        if status == 503:
+            time.sleep(self.config['sleep_time'])
+            # The retry already returns a (text, status) tuple; do not wrap
+            # it in another tuple
+            return self.process_json(text, method_name, params, headers, verbose)
+        elif status == 204:
+            # This check must precede the generic != 200 branch, which would
+            # otherwise swallow it
+            if verbose:
+                print('No content returned. Moving on. ')
+            return None, status
+        elif status != 200:
+            if verbose:
+                print('Processing failed with error ', status)
+            return None, status
+        else:
+            return res.text, status
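+
+
+if __name__ == '__main__':
+    # Minimal smoke test, mirroring how streamlit_app.py configures this
+    # client (the server URL here is an assumption for illustration):
+    client = GrobidClientGeneric()
+    client.set_config({
+        'grobid': {
+            'server': "http://localhost:8072",
+            'sleep_time': 5,
+            'url_mapping': {
+                'processText_disable_linking': "/service/process/text?disableLinking=True"
+            }
+        }
+    })
+    status, result = client.process_text("MgB2 is a superconductor.")
+    print(status, result)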
diff --git a/grobid_processors.py b/grobid_processors.py
index 85b34a8..d87cb25 100644
--- a/grobid_processors.py
+++ b/grobid_processors.py
@@ -412,7 +412,8 @@ def __init__(self, grobid_superconductors_client):
self.grobid_superconductors_client = grobid_superconductors_client
def extract_materials(self, text):
- status, result = self.grobid_superconductors_client.process_text(text.strip(), "processText_disable_linking")
+ preprocessed_text = text.strip()
+ status, result = self.grobid_superconductors_client.process_text(preprocessed_text, "processText_disable_linking")
if status != 200:
result = {}
@@ -420,10 +421,10 @@ def extract_materials(self, text):
spans = []
if 'passages' in result:
- materials = self.parse_superconductors_output(result, text)
+ materials = self.parse_superconductors_output(result, preprocessed_text)
for m in materials:
- item = {"text": text[m['offset_start']:m['offset_end']]}
+ item = {"text": preprocessed_text[m['offset_start']:m['offset_end']]}
item['offset_start'] = m['offset_start']
item['offset_end'] = m['offset_end']
@@ -502,12 +503,12 @@ def parse_superconductors_output(result, original_text):
class GrobidAggregationProcessor(GrobidProcessor, GrobidQuantitiesProcessor, GrobidMaterialsProcessor):
def __init__(self, grobid_client, grobid_quantities_client=None, grobid_superconductors_client=None):
GrobidProcessor.__init__(self, grobid_client)
- GrobidQuantitiesProcessor.__init__(self, grobid_quantities_client)
- GrobidMaterialsProcessor.__init__(self, grobid_superconductors_client)
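+        # Hold the sub-processors by composition instead of initialising the
+        # parent classes directly, so each keeps its own dedicated client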
+ self.gqp = GrobidQuantitiesProcessor(grobid_quantities_client)
+ self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client)
def process_single_text(self, text):
- extracted_quantities_spans = extract_quantities(self.grobid_quantities_client, text)
- extracted_materials_spans = extract_materials(self.grobid_superconductors_client, text)
+ extracted_quantities_spans = self.gqp.extract_quantities(text)
+ extracted_materials_spans = self.gmp.extract_materials(text)
all_entities = extracted_quantities_spans + extracted_materials_spans
entities = self.prune_overlapping_annotations(all_entities)
return entities
diff --git a/streamlit_app.py b/streamlit_app.py
index e8652a9..ae60106 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -1,8 +1,10 @@
import os
+import re
from hashlib import blake2b
from tempfile import NamedTemporaryFile
import dotenv
+from grobid_quantities.quantities import QuantitiesAPI
from langchain.llms.huggingface_hub import HuggingFaceHub
dotenv.load_dotenv(override=True)
@@ -12,6 +14,8 @@
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from document_qa_engine import DocumentQAEngine
+from grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations
+from grobid_client_generic import GrobidClientGeneric
if 'rqa' not in st.session_state:
st.session_state['rqa'] = None
@@ -38,7 +42,6 @@
if "messages" not in st.session_state:
st.session_state.messages = []
-
def new_file():
st.session_state['loaded_embeddings'] = None
st.session_state['doc_id'] = None
@@ -66,6 +69,33 @@ def init_qa(model):
return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
+@st.cache_resource
+def init_ner():
+ quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
+
+ materials_client = GrobidClientGeneric(ping=True)
+ config_materials = {
+ 'grobid': {
+ "server": os.environ['GROBID_MATERIALS_URL'],
+ 'sleep_time': 5,
+ 'timeout': 60,
+ 'url_mapping': {
+ 'processText_disable_linking': "/service/process/text?disableLinking=True",
+ # 'processText_disable_linking': "/service/process/text"
+ }
+ }
+ }
+
+ materials_client.set_config(config_materials)
+
+ gqa = GrobidAggregationProcessor(None,
+ grobid_quantities_client=quantities_client,
+ grobid_superconductors_client=materials_client
+ )
+
+ return gqa
+
+
+gqa = init_ner()
def get_file_hash(fname):
hash_md5 = blake2b()
@@ -84,7 +114,7 @@ def play_old_messages():
elif message['role'] == 'assistant':
with st.chat_message("assistant"):
if mode == "LLM":
- st.markdown(message['content'])
+ st.markdown(message['content'], unsafe_allow_html=True)
else:
st.write(message['content'])
@@ -147,6 +177,7 @@ def play_old_messages():
st.markdown(
"""After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
+ st.markdown('**NER on LLM responses**: The responses from the LLMs are post-processed to extract physical quantities, measurements and materials mentions.', unsafe_allow_html=True)
if st.session_state['git_rev'] != "unknown":
st.markdown("**Revision number**: [" + st.session_state[
'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
@@ -168,6 +199,7 @@ def play_old_messages():
chunk_size=250,
perc_overlap=0.1)
st.session_state['loaded_embeddings'] = True
+ st.session_state.messages = []
# timestamp = datetime.utcnow()
@@ -175,7 +207,7 @@ def play_old_messages():
for message in st.session_state.messages:
with st.chat_message(message["role"]):
if message['mode'] == "LLM":
- st.markdown(message["content"])
+ st.markdown(message["content"], unsafe_allow_html=True)
elif message['mode'] == "Embeddings":
st.write(message["content"])
@@ -196,7 +228,14 @@ def play_old_messages():
with st.chat_message("assistant"):
if mode == "LLM":
- st.markdown(text_response)
+            entities = gqa.process_single_text(text_response)
+            decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
+            # Replace the CSS classes with inline styles so the colours
+            # survive st.markdown(): materials in green first, then every
+            # remaining label in orange
+            decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+            decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+            st.markdown(decorated_text, unsafe_allow_html=True)
+            text_response = decorated_text
else:
st.write(text_response)
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})