diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py
index 6c2f3b3..21aa93d 100644
--- a/document_qa/document_qa_engine.py
+++ b/document_qa/document_qa_engine.py
@@ -12,7 +12,7 @@
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
 
-from grobid_processors import GrobidProcessor
+from document_qa.grobid_processors import GrobidProcessor
 
 
 class DocumentQAEngine:
diff --git a/document_qa/grobid_processors.py b/document_qa/grobid_processors.py
index d87cb25..e21b1f1 100644
--- a/document_qa/grobid_processors.py
+++ b/document_qa/grobid_processors.py
@@ -413,7 +413,8 @@ def __init__(self, grobid_superconductors_client):
 
     def extract_materials(self, text):
         preprocessed_text = text.strip()
-        status, result = self.grobid_superconductors_client.process_text(preprocessed_text, "processText_disable_linking")
+        status, result = self.grobid_superconductors_client.process_text(preprocessed_text,
+                                                                         "processText_disable_linking")
 
         if status != 200:
             result = {}
@@ -679,6 +680,7 @@ def parse_xml(self, text):
 
         return output_data
 
+
 def get_children_list_supermat(soup, use_paragraphs=False, verbose=False):
     children = []
 
@@ -697,6 +699,7 @@ def get_children_list_supermat(soup, use_paragraphs=False, verbose=False):
 
     return children
 
+
 def get_children_list_grobid(soup: object, use_paragraphs: object = True, verbose: object = False) -> object:
     children = []
 
@@ -739,4 +742,4 @@ def get_children_figures(soup: object, use_paragraphs: object = True, verbose: o
     if verbose:
         print(str(children))
 
-    return children
\ No newline at end of file
+    return children
diff --git a/streamlit_app.py b/streamlit_app.py
index 3f6c939..a0ae432 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -42,6 +42,7 @@
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+
 def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
@@ -69,6 +70,7 @@ def init_qa(model):
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
+
 @st.cache_resource
 def init_ner():
     quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
@@ -89,14 +91,16 @@ def init_ner():
     materials_client.set_config(config_materials)
 
     gqa = GrobidAggregationProcessor(None,
-                                         grobid_quantities_client=quantities_client,
-                                         grobid_superconductors_client=materials_client
-                                         )
+                                     grobid_quantities_client=quantities_client,
+                                     grobid_superconductors_client=materials_client
+                                     )
 
     return gqa
 
+
 gqa = init_ner()
 
+
 def get_file_hash(fname):
     hash_md5 = blake2b()
     with open(fname, "rb") as f:
@@ -122,7 +126,7 @@ def play_old_messages():
 is_api_key_provided = st.session_state['api_key']
 
 model = st.sidebar.radio("Model (cannot be changed after selection or upload)",
-                         ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),#, "llama-2-70b-chat"),
+                         ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),  # , "llama-2-70b-chat"),
                          index=1,
                          captions=[
                              "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
@@ -134,13 +138,15 @@
 
 if not st.session_state['api_key']:
     if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
-        api_key = st.sidebar.text_input('Huggingface API Key', type="password")# if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
+        api_key = st.sidebar.text_input('Huggingface API Key',
+                                        type="password")  # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
             os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
             st.session_state['rqa'] = init_qa(model)
     elif model == 'chatgpt-3.5-turbo':
-        api_key = st.sidebar.text_input('OpenAI API Key', type="password") #if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
+        api_key = st.sidebar.text_input('OpenAI API Key',
+                                        type="password")  # if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
             os.environ['OPENAI_API_KEY'] = api_key
@@ -177,10 +183,12 @@
 
     st.markdown(
         """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
-    st.markdown('**NER on LLM responses**: The responses from the LLMs are post-processed to extract physical quantities, measurements and materials mentions.', unsafe_allow_html=True)
+    st.markdown(
+        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract physical quantities, measurements and materials mentions.',
+        unsafe_allow_html=True)
     if st.session_state['git_rev'] != "unknown":
         st.markdown("**Revision number**: [" + st.session_state[
-            'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
+                        'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
 
     st.header("Query mode (Advanced use)")
     st.markdown(
@@ -219,11 +227,11 @@
         if mode == "Embeddings":
             with st.spinner("Generating LLM response..."):
                 text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
-                                                                          context_size=context_size)
+                                                                      context_size=context_size)
         elif mode == "LLM":
             with st.spinner("Generating response..."):
                 _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
-                                                                              context_size=context_size)
+                                                                          context_size=context_size)
 
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")