Add conversational memory with sliding window #18

Merged
merged 4 commits into from
Nov 18, 2023
11 changes: 7 additions & 4 deletions README.md
@@ -16,11 +16,14 @@ license: apache-2.0

## Introduction

Question/Answering on scientific documents using LLMs (OpenAI, Mistral, ~~LLama2,~~ etc..).
This application is the frontend for testing the RAG (Retrieval Augmented Generation) on scientific documents, that we are developing at NIMS.
Differently to most of the project, we focus on scientific articles. We target only the full-text using [Grobid](https://github.com/kermitt2/grobid) that provide and cleaner results than the raw PDF2Text converter (which is comparable with most of other solutions).
Question/Answering on scientific documents using LLMs: ChatGPT-3.5-turbo, Mistral-7b-instruct and Zephyr-7b-beta.
The Streamlit application demonstrates the implementation of RAG (Retrieval Augmented Generation) on scientific documents that we are developing at NIMS (National Institute for Materials Science) in Tsukuba, Japan.
Unlike most similar projects, we focus on scientific articles.
We target only the full text, using [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than raw PDF-to-text conversion (the approach used by most other solutions).

**NER in LLM response**: The responses from the LLMs are post-processed to extract <span stype="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span stype="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
Additionally, this frontend visualises named entities in the LLM responses: <span style="color:yellow">physical quantities and measurements</span> (extracted with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span style="color:blue">materials</span> mentions (extracted with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).

The conversation is backed by a sliding-window memory (the 4 most recent messages) that helps the model refer to information previously discussed in the chat.
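For reference, a minimal, standalone sketch of how such a window memory behaves, using LangChain's `ConversationBufferWindowMemory` (the class this PR wires in); the example turns below are made up:

```python
from langchain.memory import ConversationBufferWindowMemory

# Keep only the 4 most recent exchanges; older turns are silently dropped.
memory = ConversationBufferWindowMemory(k=4)

for i in range(6):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

# Only the last 4 question/answer pairs remain in the history string.
print(memory.load_memory_variables({})["history"])
```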

**Demos**:
- (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
30 changes: 22 additions & 8 deletions document_qa/document_qa_engine.py
@@ -23,7 +23,13 @@ class DocumentQAEngine:
embeddings_map_from_md5 = {}
embeddings_map_to_md5 = {}

def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
def __init__(self,
llm,
embedding_function,
qa_chain_type="stuff",
embeddings_root_path=None,
grobid_url=None,
):
self.embedding_function = embedding_function
self.llm = llm
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
@@ -81,14 +87,14 @@ def get_filename_from_md5(self, md5):
return self.embeddings_map_from_md5[md5]

def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
verbose=False) -> (
verbose=False, memory=None) -> (
Any, str):
# self.load_embeddings(self.embeddings_root_path)

if verbose:
print(query)

response = self._run_query(doc_id, query, context_size=context_size)
response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
response = response['output_text'] if 'output_text' in response else response

if verbose:
@@ -138,9 +144,15 @@ def _parse_json(self, response, output_parser):

return parsed_output

def _run_query(self, doc_id, query, context_size=4):
def _run_query(self, doc_id, query, memory=None, context_size=4):
relevant_documents = self._get_context(doc_id, query, context_size)
return self.chain.run(input_documents=relevant_documents, question=query)
if memory:
# Forward the conversational memory to the chain only when one is provided.
return self.chain.run(input_documents=relevant_documents,
question=query,
memory=memory)
else:
return self.chain.run(input_documents=relevant_documents,
question=query)
# return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)

def _get_context(self, doc_id, query, context_size=4):
@@ -150,6 +162,7 @@ def _get_context(self, doc_id, query, context_size=4):
return relevant_documents

def get_all_context_by_document(self, doc_id):
"""Return the full context from the document"""
db = self.embeddings_dict[doc_id]
docs = db.get()
return docs['documents']
@@ -161,6 +174,7 @@ def _get_context_multiquery(self, doc_id, query, context_size=4):
return relevant_documents

def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
"""Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
if verbose:
print("File", pdf_file_path)
filename = Path(pdf_file_path).stem
@@ -215,12 +229,11 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
collection_name=hash)


self.embeddings_root_path = None

return hash

def create_embeddings(self, pdfs_dir_path: Path):
def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
input_files = []
for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
for file_ in files:
@@ -238,7 +251,8 @@ def create_embeddings(self, pdfs_dir_path: Path):
print(data_path, "exists. Skipping it ")
continue

texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
perc_overlap=perc_overlap)
filename = metadata[0]['filename']

vector_db_document = Chroma.from_texts(texts,
26 changes: 24 additions & 2 deletions streamlit_app.py
@@ -6,6 +6,7 @@
import dotenv
from grobid_quantities.quantities import QuantitiesAPI
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.memory import ConversationBufferWindowMemory

dotenv.load_dotenv(override=True)

@@ -51,6 +52,9 @@
if 'uploaded' not in st.session_state:
st.session_state['uploaded'] = False

if 'memory' not in st.session_state:
st.session_state['memory'] = ConversationBufferWindowMemory(k=4)

st.set_page_config(
page_title="Scientific Document Insights Q/A",
page_icon="📝",
@@ -67,6 +71,11 @@ def new_file():
st.session_state['loaded_embeddings'] = None
st.session_state['doc_id'] = None
st.session_state['uploaded'] = True
st.session_state['memory'].clear()


def clear_memory():
st.session_state['memory'].clear()


# @st.cache_resource
@@ -97,6 +106,7 @@ def init_qa(model, api_key=None):
else:
st.error("The model was not loaded properly. Try reloading. ")
st.stop()
return

return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])

@@ -168,7 +178,7 @@ def play_old_messages():
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])

st.markdown(
":warning: Mistral and Zephyr are free to use, however requests might hit limits of the huggingface free API and fail. :warning: ")
":warning: Mistral and Zephyr are **FREE** to use, but requests may fail at any time if the free Hugging Face API limits are reached. Use at your own risk. :warning: ")

if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -205,6 +215,11 @@ def play_old_messages():
# else:
# is_api_key_provided = st.session_state['api_key']

st.button(
'Reset chat memory.',
on_click=clear_memory,
help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.")

st.title("📝 Scientific Document Insights Q/A")
st.subheader("Upload a scientific article in PDF, ask questions, get insights.")

@@ -297,7 +312,8 @@ def play_old_messages():
elif mode == "LLM":
with st.spinner("Generating response..."):
_, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
context_size=context_size)
context_size=context_size,
memory=st.session_state.memory)

if not text_response:
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
@@ -316,5 +332,11 @@ def play_old_messages():
st.write(text_response)
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})

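# Rebuild the sliding-window memory from the displayed chat history (question/answer pairs).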
for id in range(0, len(st.session_state.messages), 2):
question = st.session_state.messages[id]['content']
if len(st.session_state.messages) > id + 1:
answer = st.session_state.messages[id + 1]['content']
st.session_state.memory.save_context({"input": question}, {"output": answer})

elif st.session_state.loaded_embeddings and st.session_state.doc_id:
play_old_messages()
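One note on the reset button above: Streamlit's `on_click` expects a function reference, which it invokes when the button is pressed. A minimal, self-contained sketch of that pattern (illustrative, not part of this PR's diff):

```python
import streamlit as st
from langchain.memory import ConversationBufferWindowMemory

if 'memory' not in st.session_state:
    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)

def clear_memory():
    st.session_state['memory'].clear()

# Pass the callable itself; `on_click=clear_memory()` would run the function
# on every script rerun instead of when the button is clicked.
st.button('Reset chat memory', on_click=clear_memory)
```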