From 37619a9e0e0cd02362263f61ec90bf92361bc41c Mon Sep 17 00:00:00 2001
From: Luca Foppiano
Date: Wed, 22 Nov 2023 09:14:01 +0900
Subject: [PATCH 1/3] fix signature

---
 document_qa/document_qa_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py
index 196e84c..be90829 100644
--- a/document_qa/document_qa_engine.py
+++ b/document_qa/document_qa_engine.py
@@ -144,7 +144,7 @@ def _parse_json(self, response, output_parser):
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, memory=None, context_size=4):
+    def _run_query(self, doc_id, query, context_size=4, memory=None):
         relevant_documents = self._get_context(doc_id, query, context_size)
         if memory:
             return self.chain.run(input_documents=relevant_documents,

From d67901d0b258618a62a4f206b4020abd9170c03a Mon Sep 17 00:00:00 2001
From: Luca Foppiano
Date: Wed, 22 Nov 2023 09:14:29 +0900
Subject: [PATCH 2/3] fix memory wrongly reset at every reload

---
 streamlit_app.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/streamlit_app.py b/streamlit_app.py
index 8f5b172..4541ec0 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -217,7 +217,8 @@ def play_old_messages():
 
 st.button(
     'Reset chat memory.',
-    on_click=clear_memory(),
+    key="reset-memory-button",
+    on_click=clear_memory,
     help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.")
 
 st.title("📝 Scientific Document Insights Q/A")
@@ -226,7 +227,9 @@
 st.markdown(
     ":warning: Do not upload sensitive data. We **temporarily** store text from the uploaded PDF documents solely for the purpose of processing your request, and we **do not assume responsibility** for any subsequent use or handling of the data submitted to third parties LLMs.")
 
-uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
+uploaded_file = st.file_uploader("Upload an article",
+                                 type=("pdf", "txt"),
+                                 on_change=new_file,
                                  disabled=st.session_state['model'] is not None and st.session_state['model'] not in
                                           st.session_state['api_keys'],
                                  help="The full-text is extracted using Grobid. ")
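Note on PATCH 2/3: on_click=clear_memory() invoked the function during
every Streamlit rerun and registered its return value (None) as the
callback, which is why the memory was wiped on each reload; on_click
expects a callable. A minimal sketch of the difference, assuming
clear_memory resets a ConversationBufferWindowMemory kept in session
state (the k=4 window is inferred from the button's help text, not shown
in this patch):

    import streamlit as st
    from langchain.memory import ConversationBufferWindowMemory

    if 'memory' not in st.session_state:
        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)

    def clear_memory():
        # Drop all stored conversation turns from the session memory.
        st.session_state['memory'].clear()

    # Buggy: clear_memory() executes on every rerun; the button gets None.
    # st.button('Reset chat memory.', on_click=clear_memory())

    # Fixed: pass the function itself; Streamlit calls it only on click.
    st.button('Reset chat memory.', key="reset-memory-button",
              on_click=clear_memory)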
") From ad304e22fe854ccf5d77d4e974960a255523bc28 Mon Sep 17 00:00:00 2001 From: Luca Foppiano Date: Wed, 22 Nov 2023 14:58:42 +0900 Subject: [PATCH 3/3] re-implement the conversational memory access --- document_qa/document_qa_engine.py | 49 +++++++++++++++++++++---------- streamlit_app.py | 17 ++++++----- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py index be90829..f91fc6d 100644 --- a/document_qa/document_qa_engine.py +++ b/document_qa/document_qa_engine.py @@ -4,10 +4,12 @@ from typing import Union, Any from grobid_client.grobid_client import GrobidClient -from langchain.chains import create_extraction_chain -from langchain.chains.question_answering import load_qa_chain +from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain +from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \ + map_rerank_prompt from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate from langchain.retrievers import MultiQueryRetriever +from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.vectorstores import Chroma from tqdm import tqdm @@ -23,15 +25,28 @@ class DocumentQAEngine: embeddings_map_from_md5 = {} embeddings_map_to_md5 = {} + default_prompts = { + 'stuff': stuff_prompt, + 'refine': refine_prompts, + "map_reduce": map_reduce_prompt, + "map_rerank": map_rerank_prompt + } + def __init__(self, llm, embedding_function, qa_chain_type="stuff", embeddings_root_path=None, grobid_url=None, + memory=None ): self.embedding_function = embedding_function self.llm = llm + # if memory: + # prompt = self.default_prompts[qa_chain_type].PROMPT_SELECTOR.get_prompt(llm) + # self.chain = load_qa_chain(llm, chain_type=qa_chain_type, prompt=prompt, memory=memory) + # else: + self.memory = memory self.chain = load_qa_chain(llm, chain_type=qa_chain_type) if embeddings_root_path is not None: @@ -87,14 +102,14 @@ def get_filename_from_md5(self, md5): return self.embeddings_map_from_md5[md5] def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None, - verbose=False, memory=None) -> ( + verbose=False) -> ( Any, str): # self.load_embeddings(self.embeddings_root_path) if verbose: print(query) - response = self._run_query(doc_id, query, context_size=context_size, memory=memory) + response = self._run_query(doc_id, query, context_size=context_size) response = response['output_text'] if 'output_text' in response else response if verbose: @@ -144,21 +159,21 @@ def _parse_json(self, response, output_parser): return parsed_output - def _run_query(self, doc_id, query, context_size=4, memory=None): + def _run_query(self, doc_id, query, context_size=4): relevant_documents = self._get_context(doc_id, query, context_size) - if memory: - return self.chain.run(input_documents=relevant_documents, - question=query) - else: - return self.chain.run(input_documents=relevant_documents, - question=query, - memory=memory) - # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True) + response = self.chain.run(input_documents=relevant_documents, + question=query) + + if self.memory: + self.memory.save_context({"input": query}, {"output": response}) + return response def _get_context(self, doc_id, query, context_size=4): db = 
self.embeddings_dict[doc_id] retriever = db.as_retriever(search_kwargs={"k": context_size}) relevant_documents = retriever.get_relevant_documents(query) + if self.memory and len(self.memory.buffer_as_messages) > 0: + relevant_documents.append(Document(page_content="Previous conversation:\n{}\n\n".format(self.memory.buffer_as_str))) return relevant_documents def get_all_context_by_document(self, doc_id): @@ -222,11 +237,15 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o hash = metadata[0]['hash'] if hash not in self.embeddings_dict.keys(): - self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata, + self.embeddings_dict[hash] = Chroma.from_texts(texts, + embedding=self.embedding_function, + metadatas=metadata, collection_name=hash) else: self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids']) - self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata, + self.embeddings_dict[hash] = Chroma.from_texts(texts, + embedding=self.embedding_function, + metadatas=metadata, collection_name=hash) self.embeddings_root_path = None diff --git a/streamlit_app.py b/streamlit_app.py index 4541ec0..f5a9f60 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -5,6 +5,7 @@ import dotenv from grobid_quantities.quantities import QuantitiesAPI +from langchain.callbacks import PromptLayerCallbackHandler from langchain.llms.huggingface_hub import HuggingFaceHub from langchain.memory import ConversationBufferWindowMemory @@ -80,6 +81,7 @@ def clear_memory(): # @st.cache_resource def init_qa(model, api_key=None): + ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])]) if model == 'chatgpt-3.5-turbo': if api_key: chat = ChatOpenAI(model_name="gpt-3.5-turbo", @@ -108,7 +110,7 @@ def init_qa(model, api_key=None): st.stop() return - return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL']) + return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory']) @st.cache_resource @@ -315,8 +317,7 @@ def play_old_messages(): elif mode == "LLM": with st.spinner("Generating response..."): _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id, - context_size=context_size, - memory=st.session_state.memory) + context_size=context_size) if not text_response: st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.") @@ -335,11 +336,11 @@ def play_old_messages(): st.write(text_response) st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response}) - for id in range(0, len(st.session_state.messages), 2): - question = st.session_state.messages[id]['content'] - if len(st.session_state.messages) > id + 1: - answer = st.session_state.messages[id + 1]['content'] - st.session_state.memory.save_context({"input": question}, {"output": answer}) + # if len(st.session_state.messages) > 1: + # last_answer = st.session_state.messages[len(st.session_state.messages)-1] + # if last_answer['role'] == "assistant": + # last_question = st.session_state.messages[len(st.session_state.messages)-2] + # st.session_state.memory.save_context({"input": last_question['content']}, {"output": last_answer['content']}) elif st.session_state.loaded_embeddings and st.session_state.doc_id: play_old_messages()
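Note on PATCH 3/3: the conversational memory now lives on the engine
instead of being threaded through every call. After the chain answers,
_run_query records the question/answer pair with save_context; before the
next call, _get_context appends the accumulated history to the retrieved
passages as an extra pseudo-document. A minimal standalone sketch of that
flow (the sample question and answer are invented for illustration):

    from langchain.memory import ConversationBufferWindowMemory
    from langchain.schema import Document

    memory = ConversationBufferWindowMemory(k=4)

    # What _run_query now does after each answer:
    memory.save_context({"input": "Which dataset was used?"},
                        {"output": "The authors built a corpus of 100 papers."})

    # What _get_context now does before answering the next question:
    relevant_documents = []  # normally the retriever's top-k chunks
    if len(memory.buffer_as_messages) > 0:
        relevant_documents.append(Document(
            page_content="Previous conversation:\n{}\n\n".format(memory.buffer_as_str)))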
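On the streamlit_app.py side, the memory is created once in session state
and handed to DocumentQAEngine at construction time, which is why
query_document and _run_query lose their memory parameters. A hedged
sketch of the wiring (the OpenAI embeddings, the temperature default, and
the k=4 window are assumptions; in the app, chat and embeddings are built
inside init_qa):

    import os

    import streamlit as st
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.memory import ConversationBufferWindowMemory

    from document_qa.document_qa_engine import DocumentQAEngine

    if 'memory' not in st.session_state:
        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)

    chat = ChatOpenAI(model_name="gpt-3.5-turbo")
    embeddings = OpenAIEmbeddings()

    engine = DocumentQAEngine(chat, embeddings,
                              grobid_url=os.environ['GROBID_URL'],
                              memory=st.session_state['memory'])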