Merge pull request #20 from lfoppiano/fix-conversational-memory
Fix conversational memory
lfoppiano authored Nov 22, 2023
2 parents 55e39a2 + 16cf398 commit e7425e5
Showing 2 changed files with 43 additions and 23 deletions.
49 changes: 34 additions & 15 deletions document_qa/document_qa_engine.py
@@ -5,10 +5,12 @@
 
 from document_qa.grobid_processors import GrobidProcessor
 from grobid_client.grobid_client import GrobidClient
-from langchain.chains import create_extraction_chain
-from langchain.chains.question_answering import load_qa_chain
+from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
+from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
+    map_rerank_prompt
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.retrievers import MultiQueryRetriever
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
@@ -22,15 +24,28 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
+    default_prompts = {
+        'stuff': stuff_prompt,
+        'refine': refine_prompts,
+        "map_reduce": map_reduce_prompt,
+        "map_rerank": map_rerank_prompt
+    }
+
     def __init__(self,
                  llm,
                  embedding_function,
                  qa_chain_type="stuff",
                  embeddings_root_path=None,
                  grobid_url=None,
+                 memory=None
                  ):
         self.embedding_function = embedding_function
         self.llm = llm
+        # if memory:
+        #     prompt = self.default_prompts[qa_chain_type].PROMPT_SELECTOR.get_prompt(llm)
+        #     self.chain = load_qa_chain(llm, chain_type=qa_chain_type, prompt=prompt, memory=memory)
+        # else:
+        self.memory = memory
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
 
         if embeddings_root_path is not None:
@@ -86,14 +101,14 @@ def get_filename_from_md5(self, md5):
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False, memory=None) -> (
+                       verbose=False) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
+        response = self._run_query(doc_id, query, context_size=context_size)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
@@ -143,21 +158,21 @@ def _parse_json(self, response, output_parser):
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4, memory=None):
+    def _run_query(self, doc_id, query, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        if memory:
-            return self.chain.run(input_documents=relevant_documents,
-                                  question=query)
-        else:
-            return self.chain.run(input_documents=relevant_documents,
-                                  question=query,
-                                  memory=memory)
-        # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
+        response = self.chain.run(input_documents=relevant_documents,
+                                  question=query)
+
+        if self.memory:
+            self.memory.save_context({"input": query}, {"output": response})
+        return response
 
     def _get_context(self, doc_id, query, context_size=4):
         db = self.embeddings_dict[doc_id]
         retriever = db.as_retriever(search_kwargs={"k": context_size})
         relevant_documents = retriever.get_relevant_documents(query)
+        if self.memory and len(self.memory.buffer_as_messages) > 0:
+            relevant_documents.append(Document(page_content="Previous conversation:\n{}\n\n".format(self.memory.buffer_as_str)))
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
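
For context between hunks: the memory primitives the change above relies on behave as follows. This is a standalone sketch using langchain's ConversationBufferWindowMemory (the class the app imports in streamlit_app.py below); the example strings are invented, not code from this commit:

from langchain.memory import ConversationBufferWindowMemory

# Keep only the last 4 exchanges; older turns fall out of the window.
memory = ConversationBufferWindowMemory(k=4)
memory.save_context({"input": "What is GROBID?"},
                    {"output": "A library for parsing scholarly PDFs."})

# buffer_as_messages holds the stored Human/AI messages; buffer_as_str
# flattens them to "Human: ...\nAI: ..." text, which _get_context above
# wraps in a Document labelled "Previous conversation:".
assert len(memory.buffer_as_messages) == 2
print(memory.buffer_as_str)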
@@ -239,11 +254,15 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o
         hash = metadata[0]['hash']
 
         if hash not in self.embeddings_dict.keys():
-            self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
+            self.embeddings_dict[hash] = Chroma.from_texts(texts,
+                                                           embedding=self.embedding_function,
+                                                           metadatas=metadata,
                                                            collection_name=hash)
         else:
             self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
-            self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
+            self.embeddings_dict[hash] = Chroma.from_texts(texts,
+                                                           embedding=self.embedding_function,
+                                                           metadatas=metadata,
                                                            collection_name=hash)
 
         self.embeddings_root_path = None
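
Taken together, the engine-side changes move conversation state into the engine: the caller builds one memory object, passes it to the constructor, and every query both reads from it (via _get_context) and appends to it (via _run_query). A minimal usage sketch, assuming an OpenAI key in the environment and a running GROBID server; the model name, file path, URL, and the assumption that create_memory_embeddings returns the document hash used as doc_id are placeholders, not part of this commit:

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferWindowMemory

from document_qa.document_qa_engine import DocumentQAEngine

memory = ConversationBufferWindowMemory(k=4)
engine = DocumentQAEngine(ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
                          OpenAIEmbeddings(),
                          grobid_url="http://localhost:8070",
                          memory=memory)

# Assumed here: create_memory_embeddings returns the document hash that
# query_document expects as doc_id.
doc_id = engine.create_memory_embeddings("article.pdf")

_, first_answer = engine.query_document("What is the main contribution?", doc_id)
# The follow-up can say "it": the saved turns are injected as an extra
# "Previous conversation:" document before the chain runs.
_, second_answer = engine.query_document("How was it evaluated?", doc_id)

Calling memory.clear() resets the buffer between documents, which is what the app's clear_memory() helper in streamlit_app.py below is for.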
17 changes: 9 additions & 8 deletions streamlit_app.py
@@ -5,6 +5,7 @@
 
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
+from langchain.callbacks import PromptLayerCallbackHandler
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.memory import ConversationBufferWindowMemory
 
@@ -80,6 +81,7 @@ def clear_memory():
 
 # @st.cache_resource
 def init_qa(model, api_key=None):
+    ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model == 'chatgpt-3.5-turbo':
         if api_key:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
@@ -108,7 +110,7 @@ def init_qa(model, api_key=None):
             st.stop()
             return
 
-    return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
+    return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
 
 
 @st.cache_resource
@@ -316,8 +318,7 @@ def play_old_messages():
     elif mode == "LLM":
         with st.spinner("Generating response..."):
             _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                             context_size=context_size,
-                                                                             memory=st.session_state.memory)
+                                                                             context_size=context_size)
 
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
@@ -336,11 +337,11 @@ def play_old_messages():
                 st.write(text_response)
             st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
-        for id in range(0, len(st.session_state.messages), 2):
-            question = st.session_state.messages[id]['content']
-            if len(st.session_state.messages) > id + 1:
-                answer = st.session_state.messages[id + 1]['content']
-                st.session_state.memory.save_context({"input": question}, {"output": answer})
+        # if len(st.session_state.messages) > 1:
+        #     last_answer = st.session_state.messages[len(st.session_state.messages)-1]
+        #     if last_answer['role'] == "assistant":
+        #         last_question = st.session_state.messages[len(st.session_state.messages)-2]
+        #         st.session_state.memory.save_context({"input": last_question['content']}, {"output": last_answer['content']})
 
 elif st.session_state.loaded_embeddings and st.session_state.doc_id:
     play_old_messages()
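
The hunks above only show the call sites (init_qa reads st.session_state['memory'], and clear_memory() resets it), so here is one plausible shape for the session-state plumbing they assume; the window size and button label are invented for illustration:

import streamlit as st
from langchain.memory import ConversationBufferWindowMemory

# Create the shared memory once per browser session.
if 'memory' not in st.session_state:
    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)

def clear_memory():
    # Wipes the conversation buffer so the next question starts fresh.
    st.session_state['memory'].clear()

st.sidebar.button('Clear conversation', on_click=clear_memory)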
