diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py
index 49e5f5f..0986460 100644
--- a/document_qa/document_qa_engine.py
+++ b/document_qa/document_qa_engine.py
@@ -5,10 +5,12 @@
 
 from document_qa.grobid_processors import GrobidProcessor
 from grobid_client.grobid_client import GrobidClient
-from langchain.chains import create_extraction_chain
-from langchain.chains.question_answering import load_qa_chain
+from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
+from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
+    map_rerank_prompt
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
 from langchain.retrievers import MultiQueryRetriever
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from tqdm import tqdm
@@ -22,15 +24,28 @@ class DocumentQAEngine:
     embeddings_map_from_md5 = {}
     embeddings_map_to_md5 = {}
 
+    default_prompts = {
+        'stuff': stuff_prompt,
+        'refine': refine_prompts,
+        "map_reduce": map_reduce_prompt,
+        "map_rerank": map_rerank_prompt
+    }
+
     def __init__(self,
                  llm,
                  embedding_function,
                  qa_chain_type="stuff",
                  embeddings_root_path=None,
                  grobid_url=None,
+                 memory=None
                  ):
         self.embedding_function = embedding_function
         self.llm = llm
+        # if memory:
+        #     prompt = self.default_prompts[qa_chain_type].PROMPT_SELECTOR.get_prompt(llm)
+        #     self.chain = load_qa_chain(llm, chain_type=qa_chain_type, prompt=prompt, memory=memory)
+        # else:
+        self.memory = memory
         self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
 
         if embeddings_root_path is not None:
@@ -86,14 +101,14 @@ def get_filename_from_md5(self, md5):
         return self.embeddings_map_from_md5[md5]
 
     def query_document(self, query: str, doc_id, output_parser=None, context_size=4, extraction_schema=None,
-                       verbose=False, memory=None) -> (
+                       verbose=False) -> (
             Any, str):
         # self.load_embeddings(self.embeddings_root_path)
 
         if verbose:
             print(query)
 
-        response = self._run_query(doc_id, query, context_size=context_size, memory=memory)
+        response = self._run_query(doc_id, query, context_size=context_size)
         response = response['output_text'] if 'output_text' in response else response
 
         if verbose:
@@ -143,21 +158,21 @@ def _parse_json(self, response, output_parser):
 
         return parsed_output
 
-    def _run_query(self, doc_id, query, context_size=4, memory=None):
+    def _run_query(self, doc_id, query, context_size=4):
         relevant_documents = self._get_context(doc_id, query, context_size)
-        if memory:
-            return self.chain.run(input_documents=relevant_documents,
-                                  question=query)
-        else:
-            return self.chain.run(input_documents=relevant_documents,
-                                  question=query,
-                                  memory=memory)
-        # return self.chain({"input_documents": relevant_documents, "question": prompt_chat_template}, return_only_outputs=True)
+        response = self.chain.run(input_documents=relevant_documents,
+                                  question=query)
+
+        if self.memory:
+            self.memory.save_context({"input": query}, {"output": response})
+        return response
 
     def _get_context(self, doc_id, query, context_size=4):
         db = self.embeddings_dict[doc_id]
         retriever = db.as_retriever(search_kwargs={"k": context_size})
         relevant_documents = retriever.get_relevant_documents(query)
+        if self.memory and len(self.memory.buffer_as_messages) > 0:
+            relevant_documents.append(Document(page_content="Previous conversation:\n{}\n\n".format(self.memory.buffer_as_str)))
         return relevant_documents
 
     def get_all_context_by_document(self, doc_id):
@@ -239,11 +254,15 @@ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_o
         hash = metadata[0]['hash']
 
         if hash not in self.embeddings_dict.keys():
-            self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
+            self.embeddings_dict[hash] = Chroma.from_texts(texts,
+                                                           embedding=self.embedding_function,
+                                                           metadatas=metadata,
                                                            collection_name=hash)
         else:
             self.embeddings_dict[hash].delete(ids=self.embeddings_dict[hash].get()['ids'])
-            self.embeddings_dict[hash] = Chroma.from_texts(texts, embedding=self.embedding_function, metadatas=metadata,
+            self.embeddings_dict[hash] = Chroma.from_texts(texts,
+                                                           embedding=self.embedding_function,
+                                                           metadatas=metadata,
                                                            collection_name=hash)
 
         self.embeddings_root_path = None
diff --git a/streamlit_app.py b/streamlit_app.py
index 458fb3b..8a4118e 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -5,6 +5,7 @@
 
 import dotenv
 from grobid_quantities.quantities import QuantitiesAPI
+from langchain.callbacks import PromptLayerCallbackHandler
 from langchain.llms.huggingface_hub import HuggingFaceHub
 from langchain.memory import ConversationBufferWindowMemory
 
@@ -80,6 +81,7 @@ def clear_memory():
 
 # @st.cache_resource
 def init_qa(model, api_key=None):
+    ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
    if model == 'chatgpt-3.5-turbo':
        if api_key:
            chat = ChatOpenAI(model_name="gpt-3.5-turbo",
@@ -108,7 +110,7 @@ def init_qa(model, api_key=None):
         st.stop()
         return
 
-    return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
+    return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
 
 
 @st.cache_resource
@@ -316,8 +318,7 @@ def play_old_messages():
         elif mode == "LLM":
             with st.spinner("Generating response..."):
                 _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-                                                                                  context_size=context_size,
-                                                                                  memory=st.session_state.memory)
+                                                                                  context_size=context_size)
 
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
@@ -336,11 +337,11 @@ def play_old_messages():
             st.write(text_response)
             st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
-        for id in range(0, len(st.session_state.messages), 2):
-            question = st.session_state.messages[id]['content']
-            if len(st.session_state.messages) > id + 1:
-                answer = st.session_state.messages[id + 1]['content']
-                st.session_state.memory.save_context({"input": question}, {"output": answer})
+        # if len(st.session_state.messages) > 1:
+        #     last_answer = st.session_state.messages[len(st.session_state.messages)-1]
+        #     if last_answer['role'] == "assistant":
+        #         last_question = st.session_state.messages[len(st.session_state.messages)-2]
+        #         st.session_state.memory.save_context({"input": last_question['content']}, {"output": last_answer['content']})
 
 elif st.session_state.loaded_embeddings and st.session_state.doc_id:
     play_old_messages()
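
Usage note (not part of the diff): a minimal sketch of how the new memory wiring could be exercised outside Streamlit. The LLM, embedding function, GROBID endpoint and PDF path are placeholders, and create_memory_embeddings returning the stored document id is an assumption based on how streamlit_app.py tracks st.session_state.doc_id.

from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferWindowMemory

from document_qa.document_qa_engine import DocumentQAEngine

# Keep only the last few exchanges in the conversational window (placeholder value).
memory = ConversationBufferWindowMemory(k=4)

qa_engine = DocumentQAEngine(
    ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),  # placeholder LLM
    OpenAIEmbeddings(),                                     # placeholder embedding function
    grobid_url="http://localhost:8070",                     # placeholder GROBID endpoint
    memory=memory,
)

# Assumption: create_memory_embeddings returns the hash used as doc_id.
doc_id = qa_engine.create_memory_embeddings("paper.pdf", chunk_size=500, perc_overlap=0.1)

# The first question is answered from the retrieved chunks only.
_, answer = qa_engine.query_document("Which materials are studied?", doc_id, context_size=4)

# The engine saved the first exchange via memory.save_context(), so the follow-up
# query also sees a "Previous conversation:" pseudo-document appended by _get_context().
_, follow_up = qa_engine.query_document("What is their critical temperature?", doc_id, context_size=4)
print(follow_up)

With this design, _run_query() records each question/answer pair in the memory after the chain returns, and _get_context() appends the buffered conversation as an extra Document on later queries, so follow-up questions can refer to earlier answers without threading a memory argument through every call.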