diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py
index 6c2f3b3..21aa93d 100644
--- a/document_qa/document_qa_engine.py
+++ b/document_qa/document_qa_engine.py
@@ -12,7 +12,7 @@
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
 
-from grobid_processors import GrobidProcessor
+from document_qa.grobid_processors import GrobidProcessor
 
 
 class DocumentQAEngine:
diff --git a/document_qa/grobid_processors.py b/document_qa/grobid_processors.py
index d87cb25..e21b1f1 100644
--- a/document_qa/grobid_processors.py
+++ b/document_qa/grobid_processors.py
@@ -413,7 +413,8 @@ def __init__(self, grobid_superconductors_client):
 
     def extract_materials(self, text):
        preprocessed_text = text.strip()
-        status, result = self.grobid_superconductors_client.process_text(preprocessed_text, "processText_disable_linking")
+        status, result = self.grobid_superconductors_client.process_text(preprocessed_text,
+                                                                         "processText_disable_linking")
 
        if status != 200:
            result = {}
@@ -679,6 +680,7 @@ def parse_xml(self, text):
 
        return output_data
 
+
 def get_children_list_supermat(soup, use_paragraphs=False, verbose=False):
     children = []
 
@@ -697,6 +699,7 @@ def get_children_list_supermat(soup, use_paragraphs=False, verbose=False):
 
     return children
 
+
 def get_children_list_grobid(soup: object, use_paragraphs: object = True, verbose: object = False) -> object:
     children = []
 
@@ -739,4 +742,4 @@ def get_children_figures(soup: object, use_paragraphs: object = True, verbose: o
 
     if verbose:
         print(str(children))
-    return children
\ No newline at end of file
+    return children
diff --git a/streamlit_app.py b/streamlit_app.py
index 3f6c939..a0ae432 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -42,6 +42,7 @@
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+
 def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
@@ -69,6 +70,7 @@ def init_qa(model):
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
+
 @st.cache_resource
 def init_ner():
     quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
@@ -89,14 +91,16 @@ def init_ner():
     materials_client.set_config(config_materials)
 
     gqa = GrobidAggregationProcessor(None,
-            grobid_quantities_client=quantities_client,
-            grobid_superconductors_client=materials_client
-            )
+                                     grobid_quantities_client=quantities_client,
+                                     grobid_superconductors_client=materials_client
+                                     )
 
     return gqa
 
+
 gqa = init_ner()
 
+
 def get_file_hash(fname):
     hash_md5 = blake2b()
     with open(fname, "rb") as f:
@@ -122,7 +126,7 @@ def play_old_messages():
 
 is_api_key_provided = st.session_state['api_key']
 model = st.sidebar.radio("Model (cannot be changed after selection or upload)",
-                         ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),#, "llama-2-70b-chat"),
+                         ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),  # , "llama-2-70b-chat"),
                          index=1,
                          captions=[
                              "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
@@ -134,13 +138,15 @@ def play_old_messages():
 
 if not st.session_state['api_key']:
     if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
-        api_key = st.sidebar.text_input('Huggingface API Key', type="password")# if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
+        api_key = st.sidebar.text_input('Huggingface API Key',
+                                        type="password")  # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
             os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
             st.session_state['rqa'] = init_qa(model)
     elif model == 'chatgpt-3.5-turbo':
-        api_key = st.sidebar.text_input('OpenAI API Key', type="password") #if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
+        api_key = st.sidebar.text_input('OpenAI API Key',
+                                        type="password")  # if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
             os.environ['OPENAI_API_KEY'] = api_key
@@ -177,10 +183,12 @@ def play_old_messages():
     st.markdown(
         """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
 
-    st.markdown('**NER on LLM responses**: The responses from the LLMs are post-processed to extract physical quantities, measurements and materials mentions.', unsafe_allow_html=True)
+    st.markdown(
+        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract physical quantities, measurements and materials mentions.',
+        unsafe_allow_html=True)
     if st.session_state['git_rev'] != "unknown":
         st.markdown("**Revision number**: [" + st.session_state[
-            'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
+                        'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
 
     st.header("Query mode (Advanced use)")
     st.markdown(
@@ -219,11 +227,11 @@ def play_old_messages():
     if mode == "Embeddings":
         with st.spinner("Generating LLM response..."):
             text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
-                                                          context_size=context_size)
+                                                                  context_size=context_size)
     elif mode == "LLM":
         with st.spinner("Generating response..."):
             _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
-                                                              context_size=context_size)
+                                                                      context_size=context_size)
 
     if not text_response:
         st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")