move settings on the sidebar, allow env variables
lfoppiano committed Oct 27, 2023
1 parent 64372e1 commit 88c1cba
Showing 1 changed file with 78 additions and 57 deletions.
135 changes: 78 additions & 57 deletions streamlit_app.py
@@ -18,11 +18,14 @@
 from grobid_client_generic import GrobidClientGeneric
 
 if 'rqa' not in st.session_state:
-    st.session_state['rqa'] = None
+    st.session_state['rqa'] = {}
 
 if 'api_key' not in st.session_state:
     st.session_state['api_key'] = False
 
+if 'api_keys' not in st.session_state:
+    st.session_state['api_keys'] = {}
+
 if 'doc_id' not in st.session_state:
     st.session_state['doc_id'] = None
 
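The new `api_keys` map stores one key per model and, per the commit title, keys can now also come from environment variables. A condensed sketch of the fallback this commit introduces further down (the helper name `resolve_key` is hypothetical, not part of the commit):

import os
import streamlit as st

def resolve_key(env_var: str, label: str) -> str:
    # Prefer a key provided via environment variable; otherwise ask for it in the UI.
    if env_var in os.environ:
        return os.environ[env_var]
    return st.text_input(label, type="password")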
@@ -42,13 +45,16 @@
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+if 'ner_processing' not in st.session_state:
+    st.session_state['ner_processing'] = False
+
 
 def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
 
 
-@st.cache_resource
+# @st.cache_resource
 def init_qa(model):
     if model == 'chatgpt-3.5-turbo':
         chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
@@ -67,6 +73,7 @@ def init_qa(model):
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
     else:
         st.error("The model was not loaded properly. Try reloading. ")
+        st.stop()
 
     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
 
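Together with the `st.session_state['rqa'] = {}` change above, `init_qa` results are now kept per model rather than as a single engine. A minimal sketch of that caching pattern (the `get_engine` helper is illustrative, not part of the commit):

import streamlit as st

if 'rqa' not in st.session_state:
    st.session_state['rqa'] = {}  # model name -> DocumentQAEngine

def get_engine(model: str):
    # Build each engine at most once and reuse it across Streamlit reruns.
    if model not in st.session_state['rqa']:
        st.session_state['rqa'][model] = init_qa(model)  # init_qa as defined above
    return st.session_state['rqa'][model]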
@@ -94,7 +101,6 @@ def init_ner():
         grobid_quantities_client=quantities_client,
         grobid_superconductors_client=materials_client
     )
-
     return gqa
 
 
@@ -125,51 +131,52 @@ def play_old_messages():
 
 is_api_key_provided = st.session_state['api_key']
 
-model = st.sidebar.radio("Model (cannot be changed after selection or upload)",
-                         ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),  # , "llama-2-70b-chat"),
-                         index=1,
-                         captions=[
-                             "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
-                             "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
-                             # "LLama2-70B-Chat + Sentence BERT (embeddings)",
-                         ],
-                         help="Select the model you want to use.",
-                         disabled=is_api_key_provided)
-
+if not st.session_state['api_key']:
+    with st.sidebar:
+        model = st.radio(
+            "Model (cannot be changed after selection or upload)",
+            ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),  # , "llama-2-70b-chat"),
+            index=1,
+            captions=[
+                "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
+                "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
+                # "LLama2-70B-Chat + Sentence BERT (embeddings)",
+            ],
+            help="Select the model you want to use.")
+
         if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
-            api_key = st.text_input('Huggingface API Key',
-                                    type="password")  # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
+            api_key = st.text_input('Huggingface API Key',
                                    type="password") if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ[
+                'HUGGINGFACEHUB_API_TOKEN']
             st.markdown(
                 "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
 
             if api_key:
                 st.session_state['api_key'] = is_api_key_provided = True
-                os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
-                st.session_state['rqa'] = init_qa(model)
+                st.session_state['api_keys']['mistral-7b-instruct-v0.1'] = api_key
+                if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
+                    os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
+                st.session_state['rqa'][model] = init_qa(model)
 
         elif model == 'chatgpt-3.5-turbo':
-            api_key = st.text_input('OpenAI API Key',
-                                    type="password")  # if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
+            api_key = st.text_input('OpenAI API Key', type="password") if 'OPENAI_API_KEY' not in os.environ else \
+                os.environ['OPENAI_API_KEY']
             st.markdown(
                 "Get it for [Open AI](https://platform.openai.com/account/api-keys) or [Huggingface](https://huggingface.co/docs/hub/security-tokens)")
             if api_key:
                 st.session_state['api_key'] = is_api_key_provided = True
-                os.environ['OPENAI_API_KEY'] = api_key
-                st.session_state['rqa'] = init_qa(model)
-else:
-    is_api_key_provided = st.session_state['api_key']
+                st.session_state['api_keys']['chatgpt-3.5-turbo'] = api_key
+                if 'OPENAI_API_KEY' not in os.environ:
+                    os.environ['OPENAI_API_KEY'] = api_key
+                st.session_state['rqa'][model] = init_qa(model)
+# else:
+#     is_api_key_provided = st.session_state['api_key']
 
 st.title("📝 Scientific Document Insight Q&A")
 st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
 
-upload_col, radio_col, context_col = st.columns([7, 2, 2])
-with upload_col:
-    uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
-                                     disabled=not is_api_key_provided,
-                                     help="The full-text is extracted using Grobid. ")
-with radio_col:
-    mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0,
-                    help="LLM will respond the question, Embedding will show the "
-                         "paragraphs relevant to the question in the paper.")
-with context_col:
-    context_size = st.slider("Context size", 3, 10, value=4,
-                             help="Number of paragraphs to consider when answering a question",
-                             disabled=not uploaded_file)
+uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
+                                 disabled=not is_api_key_provided,
+                                 help="The full-text is extracted using Grobid. ")
 
 question = st.chat_input(
     "Ask something about the article",
@@ -178,14 +185,29 @@
 )
 
 with st.sidebar:
+    st.header("Settings")
+    mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0, horizontal=True,
+                    help="LLM will respond the question, Embedding will show the "
+                         "paragraphs relevant to the question in the paper.")
+    chunk_size = st.slider("Chunks size", 100, 2000, value=250,
+                           help="Size of chunks in which the document is partitioned",
+                           disabled=not uploaded_file)
+    context_size = st.slider("Context size", 3, 10, value=4,
+                             help="Number of chunks to consider when answering a question",
+                             disabled=not uploaded_file)
+
+    st.session_state['ner_processing'] = st.checkbox("NER processing on LLM response")
+    st.markdown(
+        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
+        unsafe_allow_html=True)
+
+    st.divider()
+
     st.header("Documentation")
     st.markdown("https://github.com/lfoppiano/document-qa")
     st.markdown(
         """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
 
-    st.markdown(
-        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
-        unsafe_allow_html=True)
     if st.session_state['git_rev'] != "unknown":
         st.markdown("**Revision number**: [" + st.session_state[
             'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
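The new `chunk_size` slider is passed to `create_memory_embeddings` in the next hunk, alongside a fixed 10% overlap. The actual partitioning happens inside `DocumentQAEngine`, which this commit does not touch; a rough sketch of how such parameters typically interact:

def chunk_text(text: str, chunk_size: int = 250, perc_overlap: float = 0.1) -> list:
    # Step slightly less than chunk_size so consecutive chunks
    # share roughly perc_overlap of their content.
    step = max(1, int(chunk_size * (1 - perc_overlap)))
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]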
@@ -203,9 +225,9 @@ def play_old_messages():
         tmp_file = NamedTemporaryFile()
         tmp_file.write(bytearray(binary))
         # hash = get_file_hash(tmp_file.name)[:10]
-        st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name,
-                                                                                              chunk_size=250,
-                                                                                              perc_overlap=0.1)
+        st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
+                                                                                                     chunk_size=chunk_size,
+                                                                                                     perc_overlap=0.1)
         st.session_state['loaded_embeddings'] = True
         st.session_state.messages = []
 
@@ -226,27 +248,26 @@ def play_old_messages():
     text_response = None
     if mode == "Embeddings":
         with st.spinner("Generating LLM response..."):
-            text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
-                                                                  context_size=context_size)
+            text_response = st.session_state['rqa'][model].query_storage(question, st.session_state.doc_id,
+                                                                         context_size=context_size)
     elif mode == "LLM":
         with st.spinner("Generating response..."):
-            _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
-                                                                      context_size=context_size)
+            _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
+                                                                             context_size=context_size)
 
     if not text_response:
         st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
 
     with st.chat_message("assistant"):
         if mode == "LLM":
-            with st.spinner("Processing NER on LLM response..."):
-                entities = gqa.process_single_text(text_response)
-                # for entity in entities:
-                #     entity
-                decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
-                decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
-                decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
-                st.markdown(decorated_text, unsafe_allow_html=True)
-                text_response = decorated_text
+            if st.session_state['ner_processing']:
+                with st.spinner("Processing NER on LLM response..."):
+                    entities = gqa.process_single_text(text_response)
+                    decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
+                    decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
+                    decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
+                    text_response = decorated_text
+            st.markdown(text_response, unsafe_allow_html=True)
         else:
             st.write(text_response)
         st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
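The order of the two restyling calls above matters: material spans are rewritten to green first, then the catch-all regex turns every remaining `label` class orange. A standalone illustration (the sample HTML string is invented):

import re

html = ('Sample: <span class="label material">MgB2</span> is superconducting '
        'below <span class="label quantity">39 K</span>.')
html = html.replace('class="label material"', 'style="color:green"')  # materials first
html = re.sub(r'class="label[^"]+"', 'style="color:orange"', html)    # remaining labels
print(html)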
