Merge pull request #18 from lfoppiano/add-memory
Add conversational memory with sliding window
lfoppiano authored Nov 18, 2023
2 parents 320f843 + b19b313 commit cc3e97d
Showing 3 changed files with 53 additions and 14 deletions.
11 changes: 7 additions & 4 deletions README.md
@@ -16,11 +16,14 @@ license: apache-2.0

## Introduction

Question/Answering on scientific documents using LLMs (OpenAI, Mistral, ~~LLama2,~~ etc..).
This application is the frontend for testing the RAG (Retrieval Augmented Generation) on scientific documents, that we are developing at NIMS.
Differently to most of the project, we focus on scientific articles. We target only the full-text using [Grobid](https://github.com/kermitt2/grobid) that provide and cleaner results than the raw PDF2Text converter (which is comparable with most of other solutions).
Question/Answering on scientific documents using LLMs: ChatGPT-3.5-turbo, Mistral-7b-instruct and Zephyr-7b-beta.
The Streamlit application demonstrates the implementation of RAG (Retrieval Augmented Generation) on scientific documents that we are developing at NIMS (National Institute for Materials Science) in Tsukuba, Japan.
Unlike most similar projects, we focus on scientific articles.
We target only the full text, using [Grobid](https://github.com/kermitt2/grobid), which provides cleaner results than raw PDF2Text conversion (the approach used by most other solutions).
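
As a rough sketch of how these pieces can be driven programmatically (the model, embeddings, Grobid URL, file name and question below are illustrative placeholders, not values prescribed by this repository):

```python
# Minimal usage sketch of the engine defined in document_qa/document_qa_engine.py.
# Assumptions: an OpenAI API key is configured and a Grobid server is reachable
# at the URL below; the model and embedding choices are examples only.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings

from document_qa.document_qa_engine import DocumentQAEngine

llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
embeddings = OpenAIEmbeddings()

engine = DocumentQAEngine(llm, embeddings, grobid_url="http://localhost:8070")

# Grobid extracts the full text, which is chunked and embedded into an in-memory store.
doc_id = engine.create_memory_embeddings("paper.pdf", chunk_size=500, perc_overlap=0.1)

# Answer a question using the 4 most relevant chunks of that document.
_, answer = engine.query_document("What is the critical temperature?", doc_id, context_size=4)
print(answer)
```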

**NER in LLM response**: The responses from the LLMs are post-processed to extract <span stype="color:yellow">physical quantities, measurements</span> (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span stype="color:blue">materials</span> mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
Additionally, this frontend visualises named entities in the LLM responses: <span style="color:yellow">physical quantities, measurements</span> (extracted with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and <span style="color:blue">materials</span> mentions (extracted with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)).
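
A hedged sketch of this post-processing step, using the `QuantitiesAPI` client imported in `streamlit_app.py` below; the server URL is a placeholder, and the `process_text` call and the shape of its result are assumptions about the grobid-quantities client, not something this commit defines:

```python
# Hypothetical sketch: annotate an LLM answer with physical quantities.
# Assumes a grobid-quantities server at the given URL and a client method
# named `process_text` returning (status_code, payload); the exact API may differ.
from grobid_quantities.quantities import QuantitiesAPI

quantities_client = QuantitiesAPI("http://localhost:8060")

llm_answer = "The sample becomes superconducting below 93 K at ambient pressure."
status, payload = quantities_client.process_text(llm_answer)

if status == 200:
    for measurement in payload.get("measurements", []):
        print(measurement)  # each entry describes a quantity with its value and unit
```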

The conversation is backed by a sliding-window memory (the 4 most recent messages) that helps the model refer to information previously discussed in the chat.
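
A minimal sketch of the sliding-window behaviour using LangChain's `ConversationBufferWindowMemory`, which this commit wires into the Streamlit app (the example conversation is made up):

```python
from langchain.memory import ConversationBufferWindowMemory

# Keep only the 4 most recent exchanges, as configured in streamlit_app.py.
memory = ConversationBufferWindowMemory(k=4)

for i in range(6):
    memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

# Older turns fall out of the window: only exchanges 2-5 remain in the history.
print(memory.load_memory_variables({})["history"])
```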

**Demos**:
- (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/
30 changes: 22 additions & 8 deletions document_qa/document_qa_engine.py
Expand Up @@ -23,7 +23,13 @@ class DocumentQAEngine:
embeddings_map_from_md5 = {}
embeddings_map_to_md5 = {}

def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None):
def __init__(self,
llm,
embedding_function,
qa_chain_type="stuff",
embeddings_root_path=None,
grobid_url=None,
):
self.embedding_function = embedding_function
self.llm = llm
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
@@ -81,14 +87,14 @@ def get_filename_from_md5(self, md5):
return self.embeddings_map_from_md5[md5]

def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
verbose=False) -> (
verbose=False, memory=None) -> (
Any, str):
# self.load_embeddings(self.embeddings_root_path)

if verbose:
print(query)

response = self._run_query(doc_id, query, context_size=context_size)
response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
response = response['output_text'] if 'output_text' in response else response

if verbose:
@@ -138,9 +144,15 @@ def _parse_json(self, response, output_parser):

return parsed_output

def _run_query(self, doc_id, query, context_size=4):
def _run_query(self, doc_id, query, memory=None, context_size=4):
relevant_documents = self._get_context(doc_id, query, context_size)
return self.chain.run(input_documents=relevant_documents, question=query)
# Pass the conversational memory to the chain only when one is provided.
if memory:
    return self.chain.run(input_documents=relevant_documents,
                          question=query,
                          memory=memory)
else:
    return self.chain.run(input_documents=relevant_documents,
                          question=query)
# return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)

def _get_context(self, doc_id, query, context_size=4):
@@ -150,6 +162,7 @@ def _get_context(self, doc_id, query, context_size=4):
return relevant_documents

def get_all_context_by_document(self, doc_id):
"""Return the full context from the document"""
db = self.embeddings_dict[doc_id]
docs = db.get()
return docs['documents']
@@ -161,6 +174,7 @@ def _get_context_multiquery(self, doc_id, query, context_size=4):
return relevant_documents

def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
"""Extract text from documents using Grobid, if chunk_size is < 0 it keep each paragraph separately"""
if verbose:
print("File", pdf_file_path)
filename = Path(pdf_file_path).stem
@@ -215,12 +229,11 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o
self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
collection_name=hash)


self.embeddings_root_path = None

return hash

def create_embeddings(self, pdfs_dir_path: Path):
def create_embeddings(self, pdfs_dir_path: Path, chunk_size=500, perc_overlap=0.1):
input_files = []
for root, dirs, files in os.walk(pdfs_dir_path, followlinks=False):
for file_ in files:
@@ -238,7 +251,8 @@ def create_embeddings(self, pdfs_dir_path: Path):
print(data_path, "exists. Skipping it ")
continue

texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=500, perc_overlap=0.1)
texts, metadata, ids = self.get_text_from_document(input_file, chunk_size=chunk_size,
perc_overlap=perc_overlap)
filename = metadata[0]['filename']

vector_db_document = Chroma.from_texts(texts,
26 changes: 24 additions & 2 deletions streamlit_app.py
Expand Up @@ -6,6 +6,7 @@
import dotenv
from grobid_quantities.quantities import QuantitiesAPI
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.memory import ConversationBufferWindowMemory

dotenv.load_dotenv(override=True)

@@ -51,6 +52,9 @@
if 'uploaded' not in st.session_state:
st.session_state['uploaded'] = False

if 'memory' not in st.session_state:
st.session_state['memory'] = ConversationBufferWindowMemory(k=4)

st.set_page_config(
page_title="Scientific Document Insights Q/A",
page_icon="📝",
@@ -67,6 +71,11 @@ def new_file():
st.session_state['loaded_embeddings'] = None
st.session_state['doc_id'] = None
st.session_state['uploaded'] = True
st.session_state['memory'].clear()


def clear_memory():
st.session_state['memory'].clear()


# @st.cache_resource
@@ -97,6 +106,7 @@ def init_qa(model, api_key=None):
else:
st.error("The model was not loaded properly. Try reloading. ")
st.stop()
return

return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])

@@ -168,7 +178,7 @@ def play_old_messages():
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])

st.markdown(
":warning: Mistral and Zephyr are free to use, however requests might hit limits of the huggingface free API and fail. :warning: ")
":warning: Mistral and Zephyr are **FREE** to use, but requests may fail at any time if they hit the limits of the free Hugging Face API. Use at your own risk. :warning: ")

if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -205,6 +215,11 @@ def play_old_messages():
# else:
# is_api_key_provided = st.session_state['api_key']

st.button(
    'Reset chat memory.',
    on_click=clear_memory,  # pass the callback itself; calling it here would clear the memory on every rerun
    help="Clear the conversational memory. Currently implemented to retain the 4 most recent messages.")

st.title("📝 Scientific Document Insights Q/A")
st.subheader("Upload a scientific article in PDF, ask questions, get insights.")

@@ -297,7 +312,8 @@ def play_old_messages():
elif mode == "LLM":
with st.spinner("Generating response..."):
_, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
context_size=context_size)
context_size=context_size,
memory=st.session_state.memory)

if not text_response:
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
@@ -316,5 +332,11 @@ def play_old_messages():
st.write(text_response)
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})

for id in range(0, len(st.session_state.messages), 2):
    question = st.session_state.messages[id]['content']
    if len(st.session_state.messages) > id + 1:
        answer = st.session_state.messages[id + 1]['content']
        st.session_state.memory.save_context({"input": question}, {"output": answer})

elif st.session_state.loaded_embeddings and st.session_state.doc_id:
play_old_messages()
