From eedfb7345f2d01eec760e3a658b578f3291f9ce3 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Fri, 1 Mar 2024 10:49:17 +0900 Subject: [PATCH 01/11] refactoring grobid client generic --- client.py | 225 --------------- document_qa/document_qa_engine.py | 24 +- document_qa/grobid_processors.py | 11 +- document_qa/ner_client_generic.py | 461 ++++++++++++++++++++++++++++++ grobid_client_generic.py | 264 ----------------- streamlit_app.py | 2 +- 6 files changed, 486 insertions(+), 501 deletions(-) delete mode 100644 client.py create mode 100644 document_qa/ner_client_generic.py delete mode 100644 grobid_client_generic.py diff --git a/client.py b/client.py deleted file mode 100644 index 42f7d19..0000000 --- a/client.py +++ /dev/null @@ -1,225 +0,0 @@ -""" Generic API Client """ -from copy import deepcopy -import json -import requests - -try: - from urlparse import urljoin -except ImportError: - from urllib.parse import urljoin - - -class ApiClient(object): - """ Client to interact with a generic Rest API. - - Subclasses should implement functionality accordingly with the provided - service methods, i.e. ``get``, ``post``, ``put`` and ``delete``. - """ - - accept_type = 'application/xml' - api_base = None - - def __init__( - self, - base_url, - username=None, - api_key=None, - status_endpoint=None, - timeout=60 - ): - """ Initialise client. - - Args: - base_url (str): The base URL to the service being used. - username (str): The username to authenticate with. - api_key (str): The API key to authenticate with. - timeout (int): Maximum time before timing out. - """ - self.base_url = base_url - self.username = username - self.api_key = api_key - self.status_endpoint = urljoin(self.base_url, status_endpoint) - self.timeout = timeout - - @staticmethod - def encode(request, data): - """ Add request content data to request body, set Content-type header. - - Should be overridden by subclasses if not using JSON encoding. - - Args: - request (HTTPRequest): The request object. - data (dict, None): Data to be encoded. - - Returns: - HTTPRequest: The request object. - """ - if data is None: - return request - - request.add_header('Content-Type', 'application/json') - request.extracted_data = json.dumps(data) - - return request - - @staticmethod - def decode(response): - """ Decode the returned data in the response. - - Should be overridden by subclasses if something else than JSON is - expected. - - Args: - response (HTTPResponse): The response object. - - Returns: - dict or None. - """ - try: - return response.json() - except ValueError as e: - return e.message - - def get_credentials(self): - """ Returns parameters to be added to authenticate the request. - - This lives on its own to make it easier to re-implement it if needed. - - Returns: - dict: A dictionary containing the credentials. - """ - return {"username": self.username, "api_key": self.api_key} - - def call_api( - self, - method, - url, - headers=None, - params=None, - data=None, - files=None, - timeout=None, - ): - """ Call API. - - This returns object containing data, with error details if applicable. - - Args: - method (str): The HTTP method to use. - url (str): Resource location relative to the base URL. - headers (dict or None): Extra request headers to set. - params (dict or None): Query-string parameters. - data (dict or None): Request body contents for POST or PUT requests. - files (dict or None: Files to be passed to the request. - timeout (int): Maximum time before timing out. - - Returns: - ResultParser or ErrorParser. - """ - headers = deepcopy(headers) or {} - headers['Accept'] = self.accept_type if 'Accept' not in headers else headers['Accept'] - params = deepcopy(params) or {} - data = data or {} - files = files or {} - #if self.username is not None and self.api_key is not None: - # params.update(self.get_credentials()) - r = requests.request( - method, - url, - headers=headers, - params=params, - files=files, - data=data, - timeout=timeout, - ) - - return r, r.status_code - - def get(self, url, params=None, **kwargs): - """ Call the API with a GET request. - - Args: - url (str): Resource location relative to the base URL. - params (dict or None): Query-string parameters. - - Returns: - ResultParser or ErrorParser. - """ - return self.call_api( - "GET", - url, - params=params, - **kwargs - ) - - def delete(self, url, params=None, **kwargs): - """ Call the API with a DELETE request. - - Args: - url (str): Resource location relative to the base URL. - params (dict or None): Query-string parameters. - - Returns: - ResultParser or ErrorParser. - """ - return self.call_api( - "DELETE", - url, - params=params, - **kwargs - ) - - def put(self, url, params=None, data=None, files=None, **kwargs): - """ Call the API with a PUT request. - - Args: - url (str): Resource location relative to the base URL. - params (dict or None): Query-string parameters. - data (dict or None): Request body contents. - files (dict or None: Files to be passed to the request. - - Returns: - An instance of ResultParser or ErrorParser. - """ - return self.call_api( - "PUT", - url, - params=params, - data=data, - files=files, - **kwargs - ) - - def post(self, url, params=None, data=None, files=None, **kwargs): - """ Call the API with a POST request. - - Args: - url (str): Resource location relative to the base URL. - params (dict or None): Query-string parameters. - data (dict or None): Request body contents. - files (dict or None: Files to be passed to the request. - - Returns: - An instance of ResultParser or ErrorParser. - """ - return self.call_api( - method="POST", - url=url, - params=params, - data=data, - files=files, - **kwargs - ) - - def service_status(self, **kwargs): - """ Call the API to get the status of the service. - - Returns: - An instance of ResultParser or ErrorParser. - """ - return self.call_api( - 'GET', - self.status_endpoint, - params={'format': 'json'}, - **kwargs - ) diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py index b207cbc..1fdf03b 100644 --- a/document_qa/document_qa_engine.py +++ b/document_qa/document_qa_engine.py @@ -85,6 +85,9 @@ def merge_passages(self, passages, chunk_size, tolerance=0.2): return new_passages_struct +class DataStorage: + + class DocumentQAEngine: llm = None @@ -123,16 +126,7 @@ def __init__(self, self.load_embeddings(self.embeddings_root_path) if grobid_url: - self.grobid_url = grobid_url - grobid_client = GrobidClient( - grobid_server=self.grobid_url, - batch_size=1000, - coordinates=["p", "title", "persName"], - sleep_time=5, - timeout=60, - check_server=True - ) - self.grobid_processor = GrobidProcessor(grobid_client) + self.grobid_processor = GrobidProcessor(grobid_url) def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None: """ @@ -204,6 +198,16 @@ def query_storage(self, query: str, doc_id, context_size=4): context_as_text = [doc.page_content for doc in documents] return context_as_text + def query_storage_and_embeddings(self, query: str, doc_id, context_size=4): + db = self.embeddings_dict[doc_id] + retriever = db.as_retriever(search_kwargs={"k": context_size}) + relevant_documents = retriever.get_relevant_documents(query, include=["embeddings"]) + + context_as_text = [doc.page_content for doc in relevant_documents] + return context_as_text + + # chroma_collection.get(include=['embeddings'])['embeddings'] + def _parse_json(self, response, output_parser): system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \ "that can process text and transform it to JSON." diff --git a/document_qa/grobid_processors.py b/document_qa/grobid_processors.py index a28138b..287b965 100644 --- a/document_qa/grobid_processors.py +++ b/document_qa/grobid_processors.py @@ -6,6 +6,7 @@ import dateparser import grobid_tei_xml from bs4 import BeautifulSoup +from grobid_client.grobid_client import GrobidClient from tqdm import tqdm @@ -127,8 +128,16 @@ def post_process(self, text): class GrobidProcessor(BaseProcessor): - def __init__(self, grobid_client): + def __init__(self, grobid_url, ping_server=True): # super().__init__() + grobid_client = GrobidClient( + grobid_server=grobid_url, + batch_size=5, + coordinates=["p", "title", "persName"], + sleep_time=5, + timeout=60, + check_server=ping_server + ) self.grobid_client = grobid_client def process_structure(self, input_path, coordinates=False): diff --git a/document_qa/ner_client_generic.py b/document_qa/ner_client_generic.py new file mode 100644 index 0000000..fe4b846 --- /dev/null +++ b/document_qa/ner_client_generic.py @@ -0,0 +1,461 @@ +import os +import time + +import yaml + +''' +This client is a generic client for any Grobid application and sub-modules. +At the moment, it supports only single document processing. + +Source: https://github.com/kermitt2/grobid-client-python +''' + +""" Generic API Client """ +from copy import deepcopy +import json +import requests + +try: + from urlparse import urljoin +except ImportError: + from urllib.parse import urljoin + + +class ApiClient(object): + """ Client to interact with a generic Rest API. + + Subclasses should implement functionality accordingly with the provided + service methods, i.e. ``get``, ``post``, ``put`` and ``delete``. + """ + + accept_type = 'application/xml' + api_base = None + + def __init__( + self, + base_url, + username=None, + api_key=None, + status_endpoint=None, + timeout=60 + ): + """ Initialise client. + + Args: + base_url (str): The base URL to the service being used. + username (str): The username to authenticate with. + api_key (str): The API key to authenticate with. + timeout (int): Maximum time before timing out. + """ + self.base_url = base_url + self.username = username + self.api_key = api_key + self.status_endpoint = urljoin(self.base_url, status_endpoint) + self.timeout = timeout + + @staticmethod + def encode(request, data): + """ Add request content data to request body, set Content-type header. + + Should be overridden by subclasses if not using JSON encoding. + + Args: + request (HTTPRequest): The request object. + data (dict, None): Data to be encoded. + + Returns: + HTTPRequest: The request object. + """ + if data is None: + return request + + request.add_header('Content-Type', 'application/json') + request.extracted_data = json.dumps(data) + + return request + + @staticmethod + def decode(response): + """ Decode the returned data in the response. + + Should be overridden by subclasses if something else than JSON is + expected. + + Args: + response (HTTPResponse): The response object. + + Returns: + dict or None. + """ + try: + return response.json() + except ValueError as e: + return e.message + + def get_credentials(self): + """ Returns parameters to be added to authenticate the request. + + This lives on its own to make it easier to re-implement it if needed. + + Returns: + dict: A dictionary containing the credentials. + """ + return {"username": self.username, "api_key": self.api_key} + + def call_api( + self, + method, + url, + headers=None, + params=None, + data=None, + files=None, + timeout=None, + ): + """ Call API. + + This returns object containing data, with error details if applicable. + + Args: + method (str): The HTTP method to use. + url (str): Resource location relative to the base URL. + headers (dict or None): Extra request headers to set. + params (dict or None): Query-string parameters. + data (dict or None): Request body contents for POST or PUT requests. + files (dict or None: Files to be passed to the request. + timeout (int): Maximum time before timing out. + + Returns: + ResultParser or ErrorParser. + """ + headers = deepcopy(headers) or {} + headers['Accept'] = self.accept_type if 'Accept' not in headers else headers['Accept'] + params = deepcopy(params) or {} + data = data or {} + files = files or {} + # if self.username is not None and self.api_key is not None: + # params.update(self.get_credentials()) + r = requests.request( + method, + url, + headers=headers, + params=params, + files=files, + data=data, + timeout=timeout, + ) + + return r, r.status_code + + def get(self, url, params=None, **kwargs): + """ Call the API with a GET request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + + Returns: + ResultParser or ErrorParser. + """ + return self.call_api( + "GET", + url, + params=params, + **kwargs + ) + + def delete(self, url, params=None, **kwargs): + """ Call the API with a DELETE request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + + Returns: + ResultParser or ErrorParser. + """ + return self.call_api( + "DELETE", + url, + params=params, + **kwargs + ) + + def put(self, url, params=None, data=None, files=None, **kwargs): + """ Call the API with a PUT request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + data (dict or None): Request body contents. + files (dict or None: Files to be passed to the request. + + Returns: + An instance of ResultParser or ErrorParser. + """ + return self.call_api( + "PUT", + url, + params=params, + data=data, + files=files, + **kwargs + ) + + def post(self, url, params=None, data=None, files=None, **kwargs): + """ Call the API with a POST request. + + Args: + url (str): Resource location relative to the base URL. + params (dict or None): Query-string parameters. + data (dict or None): Request body contents. + files (dict or None: Files to be passed to the request. + + Returns: + An instance of ResultParser or ErrorParser. + """ + return self.call_api( + method="POST", + url=url, + params=params, + data=data, + files=files, + **kwargs + ) + + def service_status(self, **kwargs): + """ Call the API to get the status of the service. + + Returns: + An instance of ResultParser or ErrorParser. + """ + return self.call_api( + 'GET', + self.status_endpoint, + params={'format': 'json'}, + **kwargs + ) + + +class NERClientGeneric(ApiClient): + + def __init__(self, config_path=None, ping=False): + self.config = None + if config_path is not None: + self.config = self._load_yaml_config_from_file(path=config_path) + super().__init__(self.config['grobid']['server']) + + if ping: + result = self.ping_service() + if not result: + raise Exception("Grobid is down.") + + os.environ['NO_PROXY'] = "nims.go.jp" + + @staticmethod + def _load_json_config_from_file(path='./config.json'): + """ + Load the json configuration + """ + config = {} + with open(path, 'r') as fp: + config = json.load(fp) + + return config + + @staticmethod + def _load_yaml_config_from_file(path='./config.yaml'): + """ + Load the YAML configuration + """ + config = {} + try: + with open(path, 'r') as the_file: + raw_configuration = the_file.read() + + config = yaml.safe_load(raw_configuration) + except Exception as e: + print("Configuration could not be loaded: ", str(e)) + exit(1) + + return config + + def set_config(self, config, ping=False): + self.config = config + if ping: + try: + result = self.ping_service() + if not result: + raise Exception("Grobid is down.") + except Exception as e: + raise Exception("Grobid is down or other problems were encountered. ", e) + + def ping_service(self): + # test if the server is up and running... + ping_url = self.get_url("ping") + + r = requests.get(ping_url) + status = r.status_code + + if status != 200: + print('GROBID server does not appear up and running ' + str(status)) + return False + else: + print("GROBID server is up and running") + return True + + def get_url(self, action): + grobid_config = self.config['grobid'] + base_url = grobid_config['server'] + action_url = base_url + grobid_config['url_mapping'][action] + + return action_url + + def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): + + files = { + 'texts': input + } + + the_url = self.get_url(method_name) + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=files, + data=params, + headers=headers + ) + + if status == 503: + time.sleep(self.config['sleep_time']) + return self.process_texts(input, method_name, params, headers) + elif status != 200: + print('Processing failed with error ' + str(status)) + return status, None + else: + return status, json.loads(res.text) + + def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): + + files = { + 'text': input + } + + the_url = self.get_url(method_name) + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=files, + data=params, + headers=headers + ) + + if status == 503: + time.sleep(self.config['sleep_time']) + return self.process_text(input, method_name, params, headers) + elif status != 200: + print('Processing failed with error ' + str(status)) + return status, None + else: + return status, json.loads(res.text) + + def process_pdf(self, + form_data: dict, + method_name='superconductors', + params={}, + headers={"Accept": "application/json"} + ): + + the_url = self.get_url(method_name) + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=form_data, + data=params, + headers=headers + ) + + if status == 503: + time.sleep(self.config['sleep_time']) + return self.process_text(input, method_name, params, headers) + elif status != 200: + print('Processing failed with error ' + str(status)) + else: + return res.text + + def process_pdfs(self, pdf_files, params={}): + pass + + def process_pdf( + self, + pdf_file, + method_name, + params={}, + headers={"Accept": "application/json"}, + verbose=False, + retry=None + ): + + files = { + 'input': ( + pdf_file, + open(pdf_file, 'rb'), + 'application/pdf', + {'Expires': '0'} + ) + } + + the_url = self.get_url(method_name) + + params, the_url = self.get_params_from_url(the_url) + + res, status = self.post( + url=the_url, + files=files, + data=params, + headers=headers + ) + + if status == 503 or status == 429: + if retry is None: + retry = self.config['max_retry'] - 1 + else: + if retry - 1 == 0: + if verbose: + print("re-try exhausted. Aborting request") + return None, status + else: + retry -= 1 + + sleep_time = self.config['sleep_time'] + if verbose: + print("Server is saturated, waiting", sleep_time, "seconds and trying again. ") + time.sleep(sleep_time) + return self.process_pdf(pdf_file, method_name, params, headers, verbose=verbose, retry=retry) + elif status != 200: + desc = None + if res.content: + c = json.loads(res.text) + desc = c['description'] if 'description' in c else None + return desc, status + elif status == 204: + # print('No content returned. Moving on. ') + return None, status + else: + return res.text, status + + def get_params_from_url(self, the_url): + """ + This method is used to pass to the URL predefined parameters, which are added in the URL format + """ + params = {} + if "?" in the_url: + split = the_url.split("?") + the_url = split[0] + params = split[1] + + params = {param.split("=")[0]: param.split("=")[1] for param in params.split("&")} + return params, the_url diff --git a/grobid_client_generic.py b/grobid_client_generic.py deleted file mode 100644 index 491808c..0000000 --- a/grobid_client_generic.py +++ /dev/null @@ -1,264 +0,0 @@ -import json -import os -import time - -import requests -import yaml - -from client import ApiClient - -''' -This client is a generic client for any Grobid application and sub-modules. -At the moment, it supports only single document processing. - -Source: https://github.com/kermitt2/grobid-client-python -''' - - -class GrobidClientGeneric(ApiClient): - - def __init__(self, config_path=None, ping=False): - self.config = None - if config_path is not None: - self.config = self.load_yaml_config_from_file(path=config_path) - super().__init__(self.config['grobid']['server']) - - if ping: - result = self.ping_grobid() - if not result: - raise Exception("Grobid is down.") - - os.environ['NO_PROXY'] = "nims.go.jp" - - @staticmethod - def load_json_config_from_file(self, path='./config.json', ping=False): - """ - Load the json configuration - """ - config = {} - with open(path, 'r') as fp: - config = json.load(fp) - - if ping: - result = self.ping_grobid() - if not result: - raise Exception("Grobid is down.") - - return config - - def load_yaml_config_from_file(self, path='./config.yaml'): - """ - Load the YAML configuration - """ - config = {} - try: - with open(path, 'r') as the_file: - raw_configuration = the_file.read() - - config = yaml.safe_load(raw_configuration) - except Exception as e: - print("Configuration could not be loaded: ", str(e)) - exit(1) - - return config - - def set_config(self, config, ping=False): - self.config = config - if ping: - try: - result = self.ping_grobid() - if not result: - raise Exception("Grobid is down.") - except Exception as e: - raise Exception("Grobid is down or other problems were encountered. ", e) - - def ping_grobid(self): - # test if the server is up and running... - ping_url = self.get_grobid_url("ping") - - r = requests.get(ping_url) - status = r.status_code - - if status != 200: - print('GROBID server does not appear up and running ' + str(status)) - return False - else: - print("GROBID server is up and running") - return True - - def get_grobid_url(self, action): - grobid_config = self.config['grobid'] - base_url = grobid_config['server'] - action_url = base_url + grobid_config['url_mapping'][action] - - return action_url - - def process_texts(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): - - files = { - 'texts': input - } - - the_url = self.get_grobid_url(method_name) - params, the_url = self.get_params_from_url(the_url) - - res, status = self.post( - url=the_url, - files=files, - data=params, - headers=headers - ) - - if status == 503: - time.sleep(self.config['sleep_time']) - return self.process_texts(input, method_name, params, headers) - elif status != 200: - print('Processing failed with error ' + str(status)) - return status, None - else: - return status, json.loads(res.text) - - def process_text(self, input, method_name='superconductors', params={}, headers={"Accept": "application/json"}): - - files = { - 'text': input - } - - the_url = self.get_grobid_url(method_name) - params, the_url = self.get_params_from_url(the_url) - - res, status = self.post( - url=the_url, - files=files, - data=params, - headers=headers - ) - - if status == 503: - time.sleep(self.config['sleep_time']) - return self.process_text(input, method_name, params, headers) - elif status != 200: - print('Processing failed with error ' + str(status)) - return status, None - else: - return status, json.loads(res.text) - - def process(self, form_data: dict, method_name='superconductors', params={}, headers={"Accept": "application/json"}): - - the_url = self.get_grobid_url(method_name) - params, the_url = self.get_params_from_url(the_url) - - res, status = self.post( - url=the_url, - files=form_data, - data=params, - headers=headers - ) - - if status == 503: - time.sleep(self.config['sleep_time']) - return self.process_text(input, method_name, params, headers) - elif status != 200: - print('Processing failed with error ' + str(status)) - else: - return res.text - - def process_pdf_batch(self, pdf_files, params={}): - pass - - def process_pdf(self, pdf_file, method_name, params={}, headers={"Accept": "application/json"}, verbose=False, - retry=None): - - files = { - 'input': ( - pdf_file, - open(pdf_file, 'rb'), - 'application/pdf', - {'Expires': '0'} - ) - } - - the_url = self.get_grobid_url(method_name) - - params, the_url = self.get_params_from_url(the_url) - - res, status = self.post( - url=the_url, - files=files, - data=params, - headers=headers - ) - - if status == 503 or status == 429: - if retry is None: - retry = self.config['max_retry'] - 1 - else: - if retry - 1 == 0: - if verbose: - print("re-try exhausted. Aborting request") - return None, status - else: - retry -= 1 - - sleep_time = self.config['sleep_time'] - if verbose: - print("Server is saturated, waiting", sleep_time, "seconds and trying again. ") - time.sleep(sleep_time) - return self.process_pdf(pdf_file, method_name, params, headers, verbose=verbose, retry=retry) - elif status != 200: - desc = None - if res.content: - c = json.loads(res.text) - desc = c['description'] if 'description' in c else None - return desc, status - elif status == 204: - # print('No content returned. Moving on. ') - return None, status - else: - return res.text, status - - def get_params_from_url(self, the_url): - params = {} - if "?" in the_url: - split = the_url.split("?") - the_url = split[0] - params = split[1] - - params = {param.split("=")[0]: param.split("=")[1] for param in params.split("&")} - return params, the_url - - def process_json(self, text, method_name="processJson", params={}, headers={"Accept": "application/json"}, - verbose=False): - files = { - 'input': ( - None, - text, - 'application/json', - {'Expires': '0'} - ) - } - - the_url = self.get_grobid_url(method_name) - - params, the_url = self.get_params_from_url(the_url) - - res, status = self.post( - url=the_url, - files=files, - data=params, - headers=headers - ) - - if status == 503: - time.sleep(self.config['sleep_time']) - return self.process_json(text, method_name, params, headers), status - elif status != 200: - if verbose: - print('Processing failed with error ', status) - return None, status - elif status == 204: - if verbose: - print('No content returned. Moving on. ') - return None, status - else: - return res.text, status diff --git a/streamlit_app.py b/streamlit_app.py index a4cf878..daf88fa 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -441,7 +441,7 @@ def generate_color_gradient(num_elements): text_response = None if mode == "Embeddings": with st.spinner("Generating LLM response..."): - text_response = st.session_state['rqa'][model].query_storage(question, st.session_state.doc_id, + text_response = st.session_state['rqa'][model].query_storage_and_embeddings(question, st.session_state.doc_id, context_size=context_size) elif mode == "LLM": with st.spinner("Generating response..."): From f684be727facda87fb9fea3164f18304d559562a Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Mon, 8 Apr 2024 22:19:32 +0900 Subject: [PATCH 02/11] decouple quantities and superconductors --- document_qa/grobid_processors.py | 70 ++++++++------------------------ 1 file changed, 17 insertions(+), 53 deletions(-) diff --git a/document_qa/grobid_processors.py b/document_qa/grobid_processors.py index 287b965..d0c9f9d 100644 --- a/document_qa/grobid_processors.py +++ b/document_qa/grobid_processors.py @@ -7,7 +7,6 @@ import grobid_tei_xml from bs4 import BeautifulSoup from grobid_client.grobid_client import GrobidClient -from tqdm import tqdm def get_span_start(type, title=None): @@ -55,49 +54,6 @@ def decorate_text_with_annotations(text, spans, tag="span"): return annotated_text -def extract_quantities(client, x_all, column_text_index): - # relevant_items = ['magnetic field strength', 'magnetic induction', 'maximum energy product', - # "magnetic flux density", "magnetic flux"] - # property_keywords = ['coercivity', 'remanence'] - - output_data = [] - - for idx, example in tqdm(enumerate(x_all), desc="extract quantities"): - text = example[column_text_index] - spans = GrobidQuantitiesProcessor(client).extract_quantities(text) - - data_record = { - "id": example[0], - "filename": example[1], - "passage_id": example[2], - "text": text, - "spans": spans - } - - output_data.append(data_record) - - return output_data - - -def extract_materials(client, x_all, column_text_index): - output_data = [] - - for idx, example in tqdm(enumerate(x_all), desc="extract materials"): - text = example[column_text_index] - spans = GrobidMaterialsProcessor(client).extract_materials(text) - data_record = { - "id": example[0], - "filename": example[1], - "passage_id": example[2], - "text": text, - "spans": spans - } - - output_data.append(data_record) - - return output_data - - def get_parsed_value_type(quantity): if 'parsedValue' in quantity and 'structure' in quantity['parsedValue']: return quantity['parsedValue']['structure']['type'] @@ -199,7 +155,7 @@ def parse_grobid_xml(self, text, coordinates=False): "subSection": "", "passage_id": "htitle", "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in - blocks_header['authors']]) + blocks_header['authors']]) }) passages.append({ @@ -302,7 +258,7 @@ class GrobidQuantitiesProcessor(BaseProcessor): def __init__(self, grobid_quantities_client): self.grobid_quantities_client = grobid_quantities_client - def extract_quantities(self, text): + def extract_quantities(self, text) -> list: status, result = self.grobid_quantities_client.process_text(text.strip()) if status != 200: @@ -570,11 +526,12 @@ def parse_superconductors_output(result, original_text): return materials -class GrobidAggregationProcessor(GrobidProcessor, GrobidQuantitiesProcessor, GrobidMaterialsProcessor): - def __init__(self, grobid_client, grobid_quantities_client=None, grobid_superconductors_client=None): - GrobidProcessor.__init__(self, grobid_client) - self.gqp = GrobidQuantitiesProcessor(grobid_quantities_client) - self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client) +class GrobidAggregationProcessor(GrobidQuantitiesProcessor, GrobidMaterialsProcessor): + def __init__(self, grobid_quantities_client=None, grobid_superconductors_client=None): + if grobid_quantities_client: + self.gqp = GrobidQuantitiesProcessor(grobid_quantities_client) + if grobid_superconductors_client: + self.gmp = GrobidMaterialsProcessor(grobid_superconductors_client) def process_single_text(self, text): extracted_quantities_spans = self.gqp.extract_quantities(text) @@ -584,10 +541,17 @@ def process_single_text(self, text): return entities def extract_quantities(self, text): - return self.gqp.extract_quantities(text) + if self.gqp: + return self.gqp.extract_quantities(text) + else: + return [] + def extract_materials(self, text): - return self.gmp.extract_materials(text) + if self.gmp: + return self.gmp.extract_materials(text) + else: + return [] @staticmethod def box_to_dict(box, color=None, type=None): From 41ad70ed06b80f1c94a12e073150c5a5d647e34e Mon Sep 17 00:00:00 2001 From: Luca Foppiano <Foppiano.Luca@nims.go.jp> Date: Mon, 8 Apr 2024 22:20:15 +0900 Subject: [PATCH 03/11] return embeddings from storage retrieval --- document_qa/document_qa_engine.py | 372 +++++++++++++++++++++++------- requirements.txt | 22 +- streamlit_app.py | 44 ++-- 3 files changed, 330 insertions(+), 108 deletions(-) diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py index 1fdf03b..97d9618 100644 --- a/document_qa/document_qa_engine.py +++ b/document_qa/document_qa_engine.py @@ -1,23 +1,43 @@ import copy import os from pathlib import Path -from typing import Union, Any +from typing import Union, Any, Optional, List, Dict, Tuple, ClassVar, Collection import tiktoken -from grobid_client.grobid_client import GrobidClient from langchain.chains import create_extraction_chain from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \ map_rerank_prompt from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from langchain.retrievers import MultiQueryRetriever from langchain.schema import Document -from langchain.vectorstores import Chroma +from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K +from langchain_community.vectorstores.faiss import FAISS +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.utils import xor_args +from langchain_core.vectorstores import VectorStore, VectorStoreRetriever from tqdm import tqdm from document_qa.grobid_processors import GrobidProcessor +def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]: + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2], result[3]) + for result in zip( + results["documents"][0], + results["metadatas"][0], + results["distances"][0], + results["embeddings"][0], + ) + ] + + class TextMerger: + """ + This class tries to replicate the RecursiveTextSplitter from LangChain, to preserve and merge the + coordinate information from the PDF document. + """ + def __init__(self, model_name=None, encoding_name="gpt2"): if model_name is not None: self.enc = tiktoken.encoding_for_model(model_name) @@ -85,52 +105,187 @@ def merge_passages(self, passages, chunk_size, tolerance=0.2): return new_passages_struct -class DataStorage: +class BaseRetrieval: + def __init__( + self, + persist_directory: Path, + embedding_function + ): + self.embedding_function = embedding_function + self.persist_directory = persist_directory + + +class AdvancedVectorStoreRetriever(VectorStoreRetriever): + allowed_search_types: ClassVar[Collection[str]] = ( + "similarity", + "similarity_score_threshold", + "mmr", + "similarity_with_embeddings" + ) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + if self.search_type == "similarity": + docs = self.vectorstore.similarity_search(query, **self.search_kwargs) + elif self.search_type == "similarity_score_threshold": + docs_and_similarities = ( + self.vectorstore.similarity_search_with_relevance_scores( + query, **self.search_kwargs + ) + ) + for doc, similarity in docs_and_similarities: + if '__similarity' not in doc.metadata.keys(): + doc.metadata['__similarity'] = similarity + + docs = [doc for doc, _ in docs_and_similarities] + elif self.search_type == "mmr": + docs = self.vectorstore.max_marginal_relevance_search( + query, **self.search_kwargs + ) + elif self.search_type == "similarity_with_embeddings": + docs_scores_and_embeddings = ( + self.vectorstore.advanced_similarity_search( + query, **self.search_kwargs + ) + ) -class DocumentQAEngine: - llm = None - qa_chain_type = None - embedding_function = None + for doc, score, embeddings in docs_scores_and_embeddings: + if '__embeddings' not in doc.metadata.keys(): + doc.metadata['__embeddings'] = embeddings + if '__similarity' not in doc.metadata.keys(): + doc.metadata['__similarity'] = score + + docs = [doc for doc, _, _ in docs_scores_and_embeddings] + else: + raise ValueError(f"search_type of {self.search_type} not allowed.") + return docs + + +class AdvancedVectorStore(VectorStore): + def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever: + tags = kwargs.pop("tags", None) or [] + tags.extend(self._get_retriever_tags()) + return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags) + + +class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @xor_args(("query_texts", "query_embeddings")) + def __query_collection( + self, + query_texts: Optional[List[str]] = None, + query_embeddings: Optional[List[List[float]]] = None, + n_results: int = 4, + where: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Query the chroma collection.""" + try: + import chromadb # noqa: F401 + except ImportError: + raise ValueError( + "Could not import chromadb python package. " + "Please install it with `pip install chromadb`." + ) + return self._collection.query( + query_texts=query_texts, + query_embeddings=query_embeddings, + n_results=n_results, + where=where, + where_document=where_document, + **kwargs, + ) + + def advanced_similarity_search( + self, + query: str, + k: int = DEFAULT_K, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> [List[Document], float, List[float]]: + docs_scores_and_embeddings = self.similarity_search_with_scores_and_embeddings(query, k, filter=filter) + return docs_scores_and_embeddings + + def similarity_search_with_scores_and_embeddings( + self, + query: str, + k: int = DEFAULT_K, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float, List[float]]]: + + if self._embedding_function is None: + results = self.__query_collection( + query_texts=[query], + n_results=k, + where=filter, + where_document=where_document, + include=['metadatas', 'documents', 'embeddings', 'distances'] + ) + else: + query_embedding = self._embedding_function.embed_query(query) + results = self.__query_collection( + query_embeddings=[query_embedding], + n_results=k, + where=filter, + where_document=where_document, + include=['metadatas', 'documents', 'embeddings', 'distances'] + ) + + return _results_to_docs_scores_and_embeddings(results) + + +class FAISSAdvancedRetrieval(FAISS): + pass + + +class NER_Retrival(VectorStore): + """ + This class implement a retrieval based on NER models. + This is an alternative retrieval to embeddings that relies on extracted entities. + """ + pass + + +engines = { + 'chroma': ChromaAdvancedRetrieval, + 'faiss': FAISSAdvancedRetrieval, + 'ner': NER_Retrival +} + + +class DataStorage: embeddings_dict = {} embeddings_map_from_md5 = {} embeddings_map_to_md5 = {} - default_prompts = { - 'stuff': stuff_prompt, - 'refine': refine_prompts, - "map_reduce": map_reduce_prompt, - "map_rerank": map_rerank_prompt - } - - def __init__(self, - llm, - embedding_function, - qa_chain_type="stuff", - embeddings_root_path=None, - grobid_url=None, - memory=None - ): + def __init__( + self, + embedding_function, + root_path: Path = None, + engine=ChromaAdvancedRetrieval, + ) -> None: + self.root_path = root_path + self.engine = engine self.embedding_function = embedding_function - self.llm = llm - self.memory = memory - self.chain = load_qa_chain(llm, chain_type=qa_chain_type) - self.text_merger = TextMerger() - if embeddings_root_path is not None: - self.embeddings_root_path = embeddings_root_path - if not os.path.exists(embeddings_root_path): - os.makedirs(embeddings_root_path) + if root_path is not None: + self.embeddings_root_path = root_path + if not os.path.exists(root_path): + os.makedirs(root_path) else: self.load_embeddings(self.embeddings_root_path) - if grobid_url: - self.grobid_processor = GrobidProcessor(grobid_url) - def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None: """ - Load the embeddings assuming they are all persisted and stored in a single directory. + Load the vector storage assuming they are all persisted and stored in a single directory. The root path of the embeddings containing one data store for each document in each subdirectory """ @@ -141,8 +296,10 @@ def load_embeddings(self, embeddings_root_path: Union[str, Path]) -> None: return for embedding_document_dir in embeddings_directories: - self.embeddings_dict[embedding_document_dir.name] = Chroma(persist_directory=embedding_document_dir.path, - embedding_function=self.embedding_function) + self.embeddings_dict[embedding_document_dir.name] = self.engine( + persist_directory=embedding_document_dir.path, + embedding_function=self.embedding_function + ) filename_list = list(Path(embedding_document_dir).glob('*.storage_filename')) if filename_list: @@ -161,9 +318,60 @@ def get_md5_from_filename(self, filename): def get_filename_from_md5(self, md5): return self.embeddings_map_from_md5[md5] - def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None, - verbose=False) -> ( - Any, str): + def embed_document(self, doc_id, texts, metadatas): + if doc_id not in self.embeddings_dict.keys(): + self.embeddings_dict[doc_id] = self.engine.from_texts(texts, + embedding=self.embedding_function, + metadatas=metadatas, + collection_name=doc_id) + else: + # Workaround Chroma (?) breaking change + self.embeddings_dict[doc_id].delete_collection() + self.embeddings_dict[doc_id] = self.engine.from_texts(texts, + embedding=self.embedding_function, + metadatas=metadatas, + collection_name=doc_id) + + self.embeddings_root_path = None + + +class DocumentQAEngine: + llm = None + qa_chain_type = None + + default_prompts = { + 'stuff': stuff_prompt, + 'refine': refine_prompts, + "map_reduce": map_reduce_prompt, + "map_rerank": map_rerank_prompt + } + + def __init__(self, + llm, + data_storage: DataStorage, + qa_chain_type="stuff", + grobid_url=None, + memory=None + ): + + self.llm = llm + self.memory = memory + self.chain = load_qa_chain(llm, chain_type=qa_chain_type) + self.text_merger = TextMerger() + self.data_storage = data_storage + + if grobid_url: + self.grobid_processor = GrobidProcessor(grobid_url) + + def query_document( + self, + query: str, + doc_id, + output_parser=None, + context_size=4, + extraction_schema=None, + verbose=False + ) -> (Any, str): # self.load_embeddings(self.embeddings_root_path) if verbose: @@ -192,16 +400,22 @@ def query_document(self, query: str, doc_id, output_parser=None, context_size=4, else: return None, response, coordinates - def query_storage(self, query: str, doc_id, context_size=4): - documents = self._get_context(doc_id, query, context_size) + def query_storage(self, query: str, doc_id, context_size=4) -> (List[Document], list): + """ + Returns the context related to a given query + """ + documents, coordinates = self._get_context(doc_id, query, context_size) context_as_text = [doc.page_content for doc in documents] - return context_as_text + return context_as_text, coordinates def query_storage_and_embeddings(self, query: str, doc_id, context_size=4): - db = self.embeddings_dict[doc_id] - retriever = db.as_retriever(search_kwargs={"k": context_size}) - relevant_documents = retriever.get_relevant_documents(query, include=["embeddings"]) + """ + Returns both the context and the embedding information from a given query + """ + db = self.data_storage.embeddings_dict[doc_id] + retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings") + relevant_documents = retriever.get_relevant_documents(query) context_as_text = [doc.page_content for doc in relevant_documents] return context_as_text @@ -229,11 +443,11 @@ def _parse_json(self, response, output_parser): return parsed_output - def _run_query(self, doc_id, query, context_size=4): + def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list): relevant_documents = self._get_context(doc_id, query, context_size) relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] for doc in - relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)] + relevant_documents] response = self.chain.run(input_documents=relevant_documents, question=query) @@ -241,33 +455,40 @@ def _run_query(self, doc_id, query, context_size=4): self.memory.save_context({"input": query}, {"output": response}) return response, relevant_document_coordinates - def _get_context(self, doc_id, query, context_size=4): - db = self.embeddings_dict[doc_id] + def _get_context(self, doc_id, query, context_size=4) -> (List[Document], list): + db = self.data_storage.embeddings_dict[doc_id] retriever = db.as_retriever(search_kwargs={"k": context_size}) relevant_documents = retriever.get_relevant_documents(query) + relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] + for doc in + relevant_documents] if self.memory and len(self.memory.buffer_as_messages) > 0: relevant_documents.append( Document( page_content="""Following, the previous question and answers. Use these information only when in the question there are unspecified references:\n{}\n\n""".format( self.memory.buffer_as_str)) ) - return relevant_documents + return relevant_documents, relevant_document_coordinates - def get_all_context_by_document(self, doc_id): - """Return the full context from the document""" - db = self.embeddings_dict[doc_id] + def get_full_context_by_document(self, doc_id): + """ + Return the full context from the document + """ + db = self.data_storage.embeddings_dict[doc_id] docs = db.get() return docs['documents'] def _get_context_multiquery(self, doc_id, query, context_size=4): - db = self.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size}) + db = self.data_storage.embeddings_dict[doc_id].as_retriever(search_kwargs={"k": context_size}) multi_query_retriever = MultiQueryRetriever.from_llm(retriever=db, llm=self.llm) relevant_documents = multi_query_retriever.get_relevant_documents(query) return relevant_documents def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False): """ - Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately + Extract text from documents using Grobid. + - if chunk_size is < 0, keeps each paragraph separately + - if chunk_size > 0, aggregate all paragraphs and split them again using an approximate chunk size """ if verbose: print("File", pdf_file_path) @@ -307,7 +528,13 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, return texts, metadatas, ids - def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1): + def create_memory_embeddings( + self, + pdf_path, + doc_id=None, + chunk_size=500, + perc_overlap=0.1 + ): texts, metadata, ids = self.get_text_from_document( pdf_path, chunk_size=chunk_size, @@ -317,25 +544,17 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o else: hash = metadata[0]['hash'] - if hash not in self.embeddings_dict.keys(): - self.embeddings_dict[hash] = Chroma.from_texts(texts, - embedding=self.embedding_function, - metadatas=metadata, - collection_name=hash) - else: - # if 'documents' in self.embeddings_dict[hash].get() and len(self.embeddings_dict[hash].get()['documents']) == 0: - # self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids']) - self.embeddings_dict[hash].delete_collection() - self.embeddings_dict[hash] = Chroma.from_texts(texts, - embedding=self.embedding_function, - metadatas=metadata, - collection_name=hash) - - self.embeddings_root_path = None + self.data_storage.embed_document(hash, texts, metadata) return hash - def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1, include_biblio=False): + def create_embeddings( + self, + pdfs_dir_path: Path, + chunk_size=500, + perc_overlap=0.1, + include_biblio=False + ): input_files = [] for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False): for file_ in files: @@ -347,17 +566,16 @@ def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0. desc="Grobid + embeddings processing"): md5 = self.calculate_md5(input_file) - data_path = os.path.join(self.embeddings_root_path, md5) + data_path = os.path.join(self.data_storage.embeddings_root_path, md5) if os.path.exists(data_path): print(data_path, "exists. Skipping it ") continue - include = ["biblio"] if include_biblio else [] + # include = ["biblio"] if include_biblio else [] texts, metadata, ids = self.get_text_from_document( input_file, chunk_size=chunk_size, - perc_overlap=perc_overlap, - include=include) + perc_overlap=perc_overlap) filename = metadata[0]['filename'] vector_db_document = Chroma.from_texts(texts, diff --git a/requirements.txt b/requirements.txt index 830b718..3038528 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,10 +4,10 @@ grobid-client-python==0.0.7 grobid_tei_xml==0.1.3 # Utils -tqdm==4.66.1 +tqdm==4.66.2 pyyaml==6.0.1 -pytest==7.4.3 -streamlit==1.29.0 +pytest==8.1.1 +streamlit==1.33.0 lxml Beautifulsoup4 python-dotenv @@ -15,13 +15,13 @@ watchdog dateparser # LLM -chromadb==0.4.19 -tiktoken==0.4.0 -openai==0.27.7 -langchain==0.0.350 -langchain-core==0.1.0 +chromadb==0.4.24 +tiktoken==0.6.0 +openai==1.16.2 +langchain==0.1.14 +langchain-core==0.1.40 typing-inspect==0.9.0 -typing_extensions==4.8.0 -pydantic==2.4.2 -sentence_transformers==2.2.2 +typing_extensions==4.11.0 +pydantic==2.6.4 +sentence_transformers==2.6.1 streamlit-pdf-viewer \ No newline at end of file diff --git a/streamlit_app.py b/streamlit_app.py index daf88fa..e8535dc 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -9,15 +9,16 @@ from langchain.memory import ConversationBufferWindowMemory from streamlit_pdf_viewer import pdf_viewer +from document_qa.ner_client_generic import NERClientGeneric + dotenv.load_dotenv(override=True) import streamlit as st from langchain.chat_models import ChatOpenAI from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings -from document_qa.document_qa_engine import DocumentQAEngine +from document_qa.document_qa_engine import DocumentQAEngine, DataStorage from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations -from grobid_client_generic import GrobidClientGeneric OPENAI_MODELS = ['gpt-3.5-turbo', "gpt-4", @@ -168,14 +169,15 @@ def init_qa(model, api_key=None): st.stop() return - return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory']) + storage = DataStorage(embeddings) + return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory']) @st.cache_resource def init_ner(): quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True) - materials_client = GrobidClientGeneric(ping=True) + materials_client = NERClientGeneric(ping=True) config_materials = { 'grobid': { "server": os.environ['GROBID_MATERIALS_URL'], @@ -190,10 +192,8 @@ def init_ner(): materials_client.set_config(config_materials) - gqa = GrobidAggregationProcessor(None, - grobid_quantities_client=quantities_client, - grobid_superconductors_client=materials_client - ) + gqa = GrobidAggregationProcessor(grobid_quantities_client=quantities_client, + grobid_superconductors_client=materials_client) return gqa @@ -340,9 +340,12 @@ def play_old_messages(): st.session_state['pdf_rendering'] = st.radio( "PDF rendering mode", - {"PDF.JS", "Native browser engine"}, - index=1, + ("unwrap", "legacy_embed"), + index=0, disabled=not uploaded_file, + help="PDF rendering engine." + "Note: The Legacy PDF viewer does not support annotations and might not work on Chrome.", + format_func=lambda q: "Legacy PDF Viewer" if q == "legacy_embed" else "Streamlit PDF Viewer (Pdf.js)" ) st.divider() @@ -441,7 +444,8 @@ def generate_color_gradient(num_elements): text_response = None if mode == "Embeddings": with st.spinner("Generating LLM response..."): - text_response = st.session_state['rqa'][model].query_storage_and_embeddings(question, st.session_state.doc_id, + text_response, coordinates = st.session_state['rqa'][model].query_storage(question, + st.session_state.doc_id, context_size=context_size) elif mode == "LLM": with st.spinner("Generating response..."): @@ -449,14 +453,14 @@ def generate_color_gradient(num_elements): st.session_state.doc_id, context_size=context_size) - annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc] - for coord_doc in coordinates] - gradients = generate_color_gradient(len(annotations)) - for i, color in enumerate(gradients): - for annotation in annotations[i]: - annotation['color'] = color - st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in - annotation_doc] + annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc] + for coord_doc in coordinates] + gradients = generate_color_gradient(len(annotations)) + for i, color in enumerate(gradients): + for annotation in annotations[i]: + annotation['color'] = color + st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in + annotation_doc] if not text_response: st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.") @@ -486,5 +490,5 @@ def generate_color_gradient(num_elements): height=800, annotation_outline_size=1, annotations=st.session_state['annotations'], - rendering='unwrap' if st.session_state['pdf_rendering'] == 'PDF.JS' else 'legacy_embed' + rendering=st.session_state['pdf_rendering'] ) From 9c16287298b60d0470dcc01dc92af9c06b75f3ef Mon Sep 17 00:00:00 2001 From: Luca Foppiano <Foppiano.Luca@nims.go.jp> Date: Mon, 8 Apr 2024 22:20:30 +0900 Subject: [PATCH 04/11] add coverage --- .github/workflows/ci-build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 33a395c..53e7aa8 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install --upgrade flake8 pytest pycodestyle + pip install --upgrade flake8 pytest pycodestyle pytest-cov if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | From 64eade5d99e7e7072022c701ed881124d1c14b0f Mon Sep 17 00:00:00 2001 From: Luca Foppiano <Foppiano.Luca@nims.go.jp> Date: Mon, 8 Apr 2024 22:50:27 +0900 Subject: [PATCH 05/11] update dockerfile --- Dockerfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 69da66e..75f97e9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,8 +15,6 @@ RUN pip3 install -r requirements.txt COPY .streamlit ./.streamlit COPY document_qa ./document_qa -COPY grobid_client_generic.py . -COPY client.py . COPY streamlit_app.py . # extract version From a5de09e8b02100b2fdc8258b4cc9c363764986be Mon Sep 17 00:00:00 2001 From: Luca Foppiano <Foppiano.Luca@nims.go.jp> Date: Tue, 9 Apr 2024 09:39:22 +0900 Subject: [PATCH 06/11] fix breaking change in API --- document_qa/document_qa_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py index a5db03a..97d9618 100644 --- a/document_qa/document_qa_engine.py +++ b/document_qa/document_qa_engine.py @@ -494,7 +494,7 @@ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, print("File", pdf_file_path) filename = Path(pdf_file_path).stem coordinates = True # if chunk_size == -1 else False - structure = self.grobid_processor.process(pdf_file_path, coordinates=coordinates) + structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates) biblio = structure['biblio'] biblio['filename'] = filename.replace(" ", "_") From 0188e45594a0d68f86f44b926986108e3cd018b1 Mon Sep 17 00:00:00 2001 From: Luca Foppiano <Foppiano.Luca@nims.go.jp> Date: Tue, 9 Apr 2024 17:39:34 +0900 Subject: [PATCH 07/11] add query analyzer with min and avg similarity --- document_qa/document_qa_engine.py | 192 ++++++------------------------ document_qa/langchain.py | 141 ++++++++++++++++++++++ requirements.txt | 4 +- 3 files changed, 180 insertions(+), 157 deletions(-) create mode 100644 document_qa/langchain.py diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py index 97d9618..2a13043 100644 --- a/document_qa/document_qa_engine.py +++ b/document_qa/document_qa_engine.py @@ -1,35 +1,23 @@ import copy import os from pathlib import Path -from typing import Union, Any, Optional, List, Dict, Tuple, ClassVar, Collection +from typing import Union, Any, List import tiktoken from langchain.chains import create_extraction_chain from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \ map_rerank_prompt +from langchain.evaluation import PairwiseEmbeddingDistanceEvalChain, load_evaluator, EmbeddingDistance from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from langchain.retrievers import MultiQueryRetriever from langchain.schema import Document -from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K -from langchain_community.vectorstores.faiss import FAISS -from langchain_core.callbacks import CallbackManagerForRetrieverRun -from langchain_core.utils import xor_args -from langchain_core.vectorstores import VectorStore, VectorStoreRetriever +from langchain_community.vectorstores.chroma import Chroma +from langchain_core.vectorstores import VectorStore from tqdm import tqdm +# from document_qa.embedding_visualiser import QueryVisualiser from document_qa.grobid_processors import GrobidProcessor - - -def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]: - return [ - (Document(page_content=result[0], metadata=result[1] or {}), result[2], result[3]) - for result in zip( - results["documents"][0], - results["metadatas"][0], - results["distances"][0], - results["embeddings"][0], - ) - ] +from document_qa.langchain import ChromaAdvancedRetrieval class TextMerger: @@ -117,135 +105,6 @@ def __init__( self.persist_directory = persist_directory -class AdvancedVectorStoreRetriever(VectorStoreRetriever): - allowed_search_types: ClassVar[Collection[str]] = ( - "similarity", - "similarity_score_threshold", - "mmr", - "similarity_with_embeddings" - ) - - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: - if self.search_type == "similarity": - docs = self.vectorstore.similarity_search(query, **self.search_kwargs) - elif self.search_type == "similarity_score_threshold": - docs_and_similarities = ( - self.vectorstore.similarity_search_with_relevance_scores( - query, **self.search_kwargs - ) - ) - for doc, similarity in docs_and_similarities: - if '__similarity' not in doc.metadata.keys(): - doc.metadata['__similarity'] = similarity - - docs = [doc for doc, _ in docs_and_similarities] - elif self.search_type == "mmr": - docs = self.vectorstore.max_marginal_relevance_search( - query, **self.search_kwargs - ) - elif self.search_type == "similarity_with_embeddings": - docs_scores_and_embeddings = ( - self.vectorstore.advanced_similarity_search( - query, **self.search_kwargs - ) - ) - - for doc, score, embeddings in docs_scores_and_embeddings: - if '__embeddings' not in doc.metadata.keys(): - doc.metadata['__embeddings'] = embeddings - if '__similarity' not in doc.metadata.keys(): - doc.metadata['__similarity'] = score - - docs = [doc for doc, _, _ in docs_scores_and_embeddings] - else: - raise ValueError(f"search_type of {self.search_type} not allowed.") - return docs - - -class AdvancedVectorStore(VectorStore): - def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever: - tags = kwargs.pop("tags", None) or [] - tags.extend(self._get_retriever_tags()) - return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags) - - -class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - @xor_args(("query_texts", "query_embeddings")) - def __query_collection( - self, - query_texts: Optional[List[str]] = None, - query_embeddings: Optional[List[List[float]]] = None, - n_results: int = 4, - where: Optional[Dict[str, str]] = None, - where_document: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Document]: - """Query the chroma collection.""" - try: - import chromadb # noqa: F401 - except ImportError: - raise ValueError( - "Could not import chromadb python package. " - "Please install it with `pip install chromadb`." - ) - return self._collection.query( - query_texts=query_texts, - query_embeddings=query_embeddings, - n_results=n_results, - where=where, - where_document=where_document, - **kwargs, - ) - - def advanced_similarity_search( - self, - query: str, - k: int = DEFAULT_K, - filter: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> [List[Document], float, List[float]]: - docs_scores_and_embeddings = self.similarity_search_with_scores_and_embeddings(query, k, filter=filter) - return docs_scores_and_embeddings - - def similarity_search_with_scores_and_embeddings( - self, - query: str, - k: int = DEFAULT_K, - filter: Optional[Dict[str, str]] = None, - where_document: Optional[Dict[str, str]] = None, - **kwargs: Any, - ) -> List[Tuple[Document, float, List[float]]]: - - if self._embedding_function is None: - results = self.__query_collection( - query_texts=[query], - n_results=k, - where=filter, - where_document=where_document, - include=['metadatas', 'documents', 'embeddings', 'distances'] - ) - else: - query_embedding = self._embedding_function.embed_query(query) - results = self.__query_collection( - query_embeddings=[query_embedding], - n_results=k, - where=filter, - where_document=where_document, - include=['metadatas', 'documents', 'embeddings', 'distances'] - ) - - return _results_to_docs_scores_and_embeddings(results) - - -class FAISSAdvancedRetrieval(FAISS): - pass - - class NER_Retrival(VectorStore): """ This class implement a retrieval based on NER models. @@ -256,7 +115,6 @@ class NER_Retrival(VectorStore): engines = { 'chroma': ChromaAdvancedRetrieval, - 'faiss': FAISSAdvancedRetrieval, 'ner': NER_Retrival } @@ -409,7 +267,7 @@ def query_storage(self, query: str, doc_id, context_size=4) -> (List[Document], context_as_text = [doc.page_content for doc in documents] return context_as_text, coordinates - def query_storage_and_embeddings(self, query: str, doc_id, context_size=4): + def query_storage_and_embeddings(self, query: str, doc_id, context_size=4) -> List[Document]: """ Returns both the context and the embedding information from a given query """ @@ -417,10 +275,35 @@ def query_storage_and_embeddings(self, query: str, doc_id, context_size=4): retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings") relevant_documents = retriever.get_relevant_documents(query) - context_as_text = [doc.page_content for doc in relevant_documents] - return context_as_text + return relevant_documents + + def analyse_query(self, query, doc_id, context_size=4): + db = self.data_storage.embeddings_dict[doc_id] + # retriever = db.as_retriever( + # search_kwargs={"k": context_size, 'score_threshold': 0.0}, + # search_type="similarity_score_threshold" + # ) + retriever = db.as_retriever(search_kwargs={"k": context_size}, search_type="similarity_with_embeddings") + relevant_documents = retriever.get_relevant_documents(query) + relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] + for doc in + relevant_documents] + all_documents = db.get(include=['documents', 'metadatas', 'embeddings']) + # all_documents_embeddings = all_documents["embeddings"] + # query_embedding = db._embedding_function.embed_query(query) + + # distance_evaluator = load_evaluator("pairwise_embedding_distance", + # embeddings=db._embedding_function, + # distance_metric=EmbeddingDistance.EUCLIDEAN) - # chroma_collection.get(include=['embeddings'])['embeddings'] + # distance_evaluator.evaluate_string_pairs(query=query_embedding, documents="") + + similarities = [doc.metadata['__similarity'] for doc in relevant_documents] + min_similarity = min(similarities) + mean_similarity = sum(similarities) / len(similarities) + coefficient = min_similarity - mean_similarity + + return f"Coefficient: {coefficient}, (Min similarity {min_similarity}, Mean similarity: {mean_similarity})", relevant_document_coordinates def _parse_json(self, response, output_parser): system_message = "You are an useful assistant expert in materials science, physics, and chemistry " \ @@ -444,10 +327,7 @@ def _parse_json(self, response, output_parser): return parsed_output def _run_query(self, doc_id, query, context_size=4) -> (List[Document], list): - relevant_documents = self._get_context(doc_id, query, context_size) - relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else [] - for doc in - relevant_documents] + relevant_documents, relevant_document_coordinates = self._get_context(doc_id, query, context_size) response = self.chain.run(input_documents=relevant_documents, question=query) diff --git a/document_qa/langchain.py b/document_qa/langchain.py new file mode 100644 index 0000000..30c1467 --- /dev/null +++ b/document_qa/langchain.py @@ -0,0 +1,141 @@ +from pathlib import Path +from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection + +from langchain.schema import Document +from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.utils import xor_args +from langchain_core.vectorstores import VectorStore, VectorStoreRetriever + + +class AdvancedVectorStoreRetriever(VectorStoreRetriever): + allowed_search_types: ClassVar[Collection[str]] = ( + "similarity", + "similarity_score_threshold", + "mmr", + "similarity_with_embeddings" + ) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + + if self.search_type == "similarity_with_embeddings": + docs_scores_and_embeddings = ( + self.vectorstore.advanced_similarity_search( + query, **self.search_kwargs + ) + ) + + for doc, score, embeddings in docs_scores_and_embeddings: + if '__embeddings' not in doc.metadata.keys(): + doc.metadata['__embeddings'] = embeddings + if '__similarity' not in doc.metadata.keys(): + doc.metadata['__similarity'] = score + + docs = [doc for doc, _, _ in docs_scores_and_embeddings] + elif self.search_type == "similarity_score_threshold": + docs_and_similarities = ( + self.vectorstore.similarity_search_with_relevance_scores( + query, **self.search_kwargs + ) + ) + for doc, similarity in docs_and_similarities: + if '__similarity' not in doc.metadata.keys(): + doc.metadata['__similarity'] = similarity + + docs = [doc for doc, _ in docs_and_similarities] + else: + docs = super()._get_relevant_documents(query, run_manager=run_manager) + + return docs + + +class AdvancedVectorStore(VectorStore): + def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever: + tags = kwargs.pop("tags", None) or [] + tags.extend(self._get_retriever_tags()) + return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags) + + +class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @xor_args(("query_texts", "query_embeddings")) + def __query_collection( + self, + query_texts: Optional[List[str]] = None, + query_embeddings: Optional[List[List[float]]] = None, + n_results: int = 4, + where: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Query the chroma collection.""" + try: + import chromadb # noqa: F401 + except ImportError: + raise ValueError( + "Could not import chromadb python package. " + "Please install it with `pip install chromadb`." + ) + return self._collection.query( + query_texts=query_texts, + query_embeddings=query_embeddings, + n_results=n_results, + where=where, + where_document=where_document, + **kwargs, + ) + + def advanced_similarity_search( + self, + query: str, + k: int = DEFAULT_K, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> [List[Document], float, List[float]]: + docs_scores_and_embeddings = self.similarity_search_with_scores_and_embeddings(query, k, filter=filter) + return docs_scores_and_embeddings + + def similarity_search_with_scores_and_embeddings( + self, + query: str, + k: int = DEFAULT_K, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float, List[float]]]: + + if self._embedding_function is None: + results = self.__query_collection( + query_texts=[query], + n_results=k, + where=filter, + where_document=where_document, + include=['metadatas', 'documents', 'embeddings', 'distances'] + ) + else: + query_embedding = self._embedding_function.embed_query(query) + results = self.__query_collection( + query_embeddings=[query_embedding], + n_results=k, + where=filter, + where_document=where_document, + include=['metadatas', 'documents', 'embeddings', 'distances'] + ) + + return _results_to_docs_scores_and_embeddings(results) + + +def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]: + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2], result[3]) + for result in zip( + results["documents"][0], + results["metadatas"][0], + results["distances"][0], + results["embeddings"][0], + ) + ] diff --git a/requirements.txt b/requirements.txt index 3038528..9b4f9bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,6 @@ typing-inspect==0.9.0 typing_extensions==4.11.0 pydantic==2.6.4 sentence_transformers==2.6.1 -streamlit-pdf-viewer \ No newline at end of file +streamlit-pdf-viewer +umap-learn +plotly \ No newline at end of file From d74cacde02f374a2dd729ba25de49ad80ba71ba9 Mon Sep 17 00:00:00 2001 From: Luca Foppiano <Foppiano.Luca@nims.go.jp> Date: Tue, 9 Apr 2024 17:39:39 +0900 Subject: [PATCH 08/11] update application --- streamlit_app.py | 64 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/streamlit_app.py b/streamlit_app.py index e8535dc..643baa7 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -5,8 +5,11 @@ import dotenv from grobid_quantities.quantities import QuantitiesAPI -from langchain.llms.huggingface_hub import HuggingFaceHub from langchain.memory import ConversationBufferWindowMemory +from langchain_community.chat_models.openai import ChatOpenAI +from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings +from langchain_community.embeddings.openai import OpenAIEmbeddings +from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint from streamlit_pdf_viewer import pdf_viewer from document_qa.ner_client_generic import NERClientGeneric @@ -14,9 +17,6 @@ dotenv.load_dotenv(override=True) import streamlit as st -from langchain.chat_models import ChatOpenAI -from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings - from document_qa.document_qa_engine import DocumentQAEngine, DataStorage from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_text_with_annotations @@ -157,9 +157,11 @@ def init_qa(model, api_key=None): embeddings = OpenAIEmbeddings() elif model in OPEN_MODELS: - chat = HuggingFaceHub( + chat = HuggingFaceEndpoint( repo_id=OPEN_MODELS[model], - model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048} + temperature=0.01, + max_new_tokens=2048, + model_kwargs={"max_length": 4096} ) embeddings = HuggingFaceEmbeddings( model_name="all-MiniLM-L6-v2") @@ -305,16 +307,24 @@ def play_old_messages(): disabled=not uploaded_file ) +query_modes = { + "llm": "LLM Q/A", + "embeddings": "Embeddings", + "question_coefficient": "Question coefficient" +} + with st.sidebar: st.header("Settings") mode = st.radio( "Query mode", - ("LLM", "Embeddings"), + ("llm", "embeddings", "question_coefficient"), disabled=not uploaded_file, index=0, horizontal=True, + format_func=lambda x: query_modes[x], help="LLM will respond the question, Embedding will show the " - "paragraphs relevant to the question in the paper." + "relevant paragraphs to the question in the paper. " + "Question coefficient attempt to estimate how effective the question will be answered." ) # Add a checkbox for showing annotations @@ -429,10 +439,12 @@ def generate_color_gradient(num_elements): if st.session_state.loaded_embeddings and question and len(question) > 0 and st.session_state.doc_id: for message in st.session_state.messages: with st.chat_message(message["role"]): - if message['mode'] == "LLM": + if message['mode'] == "llm": st.markdown(message["content"], unsafe_allow_html=True) - elif message['mode'] == "Embeddings": + elif message['mode'] == "embeddings": st.write(message["content"]) + if message['mode'] == "question_coefficient": + st.markdown(message["content"], unsafe_allow_html=True) if model not in st.session_state['rqa']: st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `") st.stop() @@ -442,16 +454,28 @@ def generate_color_gradient(num_elements): st.session_state.messages.append({"role": "user", "mode": mode, "content": question}) text_response = None - if mode == "Embeddings": + if mode == "embeddings": + with st.spinner("Fetching the relevant context..."): + text_response, coordinates = st.session_state['rqa'][model].query_storage( + question, + st.session_state.doc_id, + context_size=context_size + ) + elif mode == "llm": with st.spinner("Generating LLM response..."): - text_response, coordinates = st.session_state['rqa'][model].query_storage(question, - st.session_state.doc_id, - context_size=context_size) - elif mode == "LLM": - with st.spinner("Generating response..."): - _, text_response, coordinates = st.session_state['rqa'][model].query_document(question, - st.session_state.doc_id, - context_size=context_size) + _, text_response, coordinates = st.session_state['rqa'][model].query_document( + question, + st.session_state.doc_id, + context_size=context_size + ) + + elif mode == "question_coefficient": + with st.spinner("Estimate question/context relevancy..."): + text_response, coordinates = st.session_state['rqa'][model].analyse_query( + question, + st.session_state.doc_id, + context_size=context_size + ) annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc] for coord_doc in coordinates] @@ -466,7 +490,7 @@ def generate_color_gradient(num_elements): st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.") with st.chat_message("assistant"): - if mode == "LLM": + if mode == "llm": if st.session_state['ner_processing']: with st.spinner("Processing NER on LLM response..."): entities = gqa.process_single_text(text_response) From 7bc374b1797899211ef41ad227bc31767fc5063c Mon Sep 17 00:00:00 2001 From: Luca Foppiano <Foppiano.Luca@nims.go.jp> Date: Mon, 6 May 2024 14:24:50 +0900 Subject: [PATCH 09/11] update to mistral v0.2, add selectable embeddings --- streamlit_app.py | 64 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/streamlit_app.py b/streamlit_app.py index 643baa7..6ff5637 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -24,9 +24,23 @@ "gpt-4", "gpt-4-1106-preview"] +OPENAI_EMBEDDINGS = [ + 'text-embedding-ada-002', + 'text-embedding-3-large', + 'openai-text-embedding-3-small' +] + OPEN_MODELS = { - 'mistral-7b-instruct-v0.1': 'mistralai/Mistral-7B-Instruct-v0.1', + 'mistral-7b-instruct-v0.2': 'mistralai/Mistral-7B-Instruct-v0.2', "zephyr-7b-beta": 'HuggingFaceH4/zephyr-7b-beta' + # 'Phi-3-mini-128k-instruct': "microsoft/Phi-3-mini-128k-instruct", + # 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct" +} + +DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)' +OPEN_EMBEDDINGS = { + DEFAULT_OPEN_EMBEDDING_NAME: 'all-MiniLM-L6-v2', + 'Salesforce/SFR-Embedding-Mistral': 'Salesforce/SFR-Embedding-Mistral' } DISABLE_MEMORY = ['zephyr-7b-beta'] @@ -83,6 +97,9 @@ if 'pdf_rendering' not in st.session_state: st.session_state['pdf_rendering'] = None +if 'embeddings' not in st.session_state: + st.session_state['embeddings'] = None + st.set_page_config( page_title="Scientific Document Insights Q/A", page_icon="📝", @@ -139,24 +156,34 @@ def clear_memory(): # @st.cache_resource -def init_qa(model, api_key=None): +def init_qa(model, embeddings_name=None, api_key=None): ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])]) if model in OPENAI_MODELS: + if embeddings_name is None: + embeddings_name = 'text-embedding-ada-002' + st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if api_key: chat = ChatOpenAI(model_name=model, temperature=0, openai_api_key=api_key, frequency_penalty=0.1) - embeddings = OpenAIEmbeddings(openai_api_key=api_key) + if embeddings_name not in OPENAI_EMBEDDINGS: + st.error(f"The embeddings provided {embeddings_name} are not supported by this model {model}.") + st.stop() + return + embeddings = OpenAIEmbeddings(model=embeddings_name, openai_api_key=api_key) else: chat = ChatOpenAI(model_name=model, temperature=0, frequency_penalty=0.1) - embeddings = OpenAIEmbeddings() + embeddings = OpenAIEmbeddings(model=embeddings_name) elif model in OPEN_MODELS: + if embeddings_name is None: + embeddings_name = DEFAULT_OPEN_EMBEDDING_NAME + chat = HuggingFaceEndpoint( repo_id=OPEN_MODELS[model], temperature=0.01, @@ -164,7 +191,7 @@ def init_qa(model, api_key=None): model_kwargs={"max_length": 4096} ) embeddings = HuggingFaceEmbeddings( - model_name="all-MiniLM-L6-v2") + model_name=OPEN_EMBEDDINGS[embeddings_name]) st.session_state['memory'] = ConversationBufferWindowMemory(k=4) if model not in DISABLE_MEMORY else None else: st.error("The model was not loaded properly. Try reloading. ") @@ -231,15 +258,25 @@ def play_old_messages(): "Model:", options=OPENAI_MODELS + list(OPEN_MODELS.keys()), index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index( - "zephyr-7b-beta") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else ( + "mistral-7b-instruct-v0.2") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else ( OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]), placeholder="Select model", help="Select the LLM model:", disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] ) + embedding_choices = OPENAI_EMBEDDINGS if model in OPENAI_MODELS else OPEN_EMBEDDINGS + + st.session_state['embeddings'] = embedding_name = st.selectbox( + "Embeddings:", + options=embedding_choices, + index=0, + placeholder="Select embedding", + help="Select the Embedding function:", + disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'] + ) st.markdown( - ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa/tree/review-interface#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ") + ":warning: [Usage disclaimer](https://github.com/lfoppiano/document-qa?tab=readme-ov-file#disclaimer-on-data-security-and-privacy-%EF%B8%8F) :warning: ") if (model in OPEN_MODELS) and model not in st.session_state['api_keys']: if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ: @@ -256,7 +293,7 @@ def play_old_messages(): st.session_state['api_keys'][model] = api_key # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ: # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key - st.session_state['rqa'][model] = init_qa(model) + st.session_state['rqa'][model] = init_qa(model, embedding_name) elif model in OPENAI_MODELS and model not in st.session_state['api_keys']: if 'OPENAI_API_KEY' not in os.environ: @@ -270,9 +307,9 @@ def play_old_messages(): with st.spinner("Preparing environment"): st.session_state['api_keys'][model] = api_key if 'OPENAI_API_KEY' not in os.environ: - st.session_state['rqa'][model] = init_qa(model, api_key) + st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings'], api_key) else: - st.session_state['rqa'][model] = init_qa(model) + st.session_state['rqa'][model] = init_qa(model, st.session_state['embeddings']) # else: # is_api_key_provided = st.session_state['api_key'] @@ -371,10 +408,13 @@ def play_old_messages(): st.header("Query mode (Advanced use)") st.markdown( - """By default, the mode is set to LLM (Language Model) which enables question/answering. You can directly ask questions related to the document content, and the system will answer the question using content from the document.""") + """By default, the mode is set to LLM (Language Model) which enables question/answering. + You can directly ask questions related to the document content, and the system will answer the question using content from the document.""") st.markdown( - """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """) + """If you switch the mode to "Embedding," the system will return specific chunks from the document + that are semantically related to your query. This mode helps to test why sometimes the answers are not + satisfying or incomplete. """) if uploaded_file and not st.session_state.loaded_embeddings: if model not in st.session_state['api_keys']: From ffa83eacd79dee5e28a743c707e596bbf2c92df7 Mon Sep 17 00:00:00 2001 From: Luca Foppiano <luca@foppiano.org> Date: Sat, 22 Jun 2024 22:00:12 +0900 Subject: [PATCH 10/11] remove zephyr --- streamlit_app.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/streamlit_app.py b/streamlit_app.py index 6ff5637..14d2d92 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -31,10 +31,9 @@ ] OPEN_MODELS = { - 'mistral-7b-instruct-v0.2': 'mistralai/Mistral-7B-Instruct-v0.2', - "zephyr-7b-beta": 'HuggingFaceH4/zephyr-7b-beta' + 'mistral-7b-instruct-v0.3': 'mistralai/Mistral-7B-Instruct-v0.2', # 'Phi-3-mini-128k-instruct': "microsoft/Phi-3-mini-128k-instruct", - # 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct" + 'Phi-3-mini-4k-instruct': "microsoft/Phi-3-mini-4k-instruct" } DEFAULT_OPEN_EMBEDDING_NAME = 'Default (all-MiniLM-L6-v2)' From 53c8debc4396ec14a8c2f958ea33e80dec59a2f9 Mon Sep 17 00:00:00 2001 From: Luca Foppiano <luca@foppiano.org> Date: Sat, 22 Jun 2024 22:04:47 +0900 Subject: [PATCH 11/11] get data availability statement as context for QA --- document_qa/grobid_processors.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/document_qa/grobid_processors.py b/document_qa/grobid_processors.py index a0fd022..e8cc8e3 100644 --- a/document_qa/grobid_processors.py +++ b/document_qa/grobid_processors.py @@ -183,6 +183,7 @@ def parse_grobid_xml(self, text, coordinates=False): }) text_blocks_body = get_xml_nodes_body(soup, verbose=False, use_paragraphs=True) + text_blocks_body.extend(get_xml_nodes_back(soup, verbose=False, use_paragraphs=True)) use_paragraphs = True if not use_paragraphs: @@ -800,6 +801,20 @@ def get_xml_nodes_body(soup: object, use_paragraphs: bool = True, verbose: bool return nodes +def get_xml_nodes_back(soup: object, use_paragraphs: bool = True, verbose: bool = False) -> list: + nodes = [] + tag_name = "p" if use_paragraphs else "s" + for child in soup.TEI.children: + if child.name == 'text': + nodes.extend( + [subsubchild for subchild in child.find_all("back") for subsubchild in subchild.find_all(tag_name)]) + + if verbose: + print(str(nodes)) + + return nodes + + def get_xml_nodes_figures(soup: object, verbose: bool = False) -> list: children = [] for child in soup.TEI.children: