diff --git a/README.md b/README.md index be27b28..0b030a9 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ We target only the full-text using [Grobid](https://github.com/kermitt2/grobid) Additionally, this frontend provides the visualisation of named entities on LLM responses to extract physical quantities, measurements (with [grobid-quantities](https://github.com/kermitt2/grobid-quantities)) and materials mentions (with [grobid-superconductors](https://github.com/lfoppiano/grobid-superconductors)). -The conversation is backed up by a sliding window memory (top 4 more recent messages) that help refers to information previously discussed in the chat. +The conversation is kept in memory up by a buffered sliding window memory (top 4 more recent messages) and the messages are injected in the context as "previous messages". **Demos**: - (on HuggingFace spaces): https://lfoppiano-document-qa.hf.space/ diff --git a/document_qa/document_qa_engine.py b/document_qa/document_qa_engine.py index 0986460..ac9eba2 100644 --- a/document_qa/document_qa_engine.py +++ b/document_qa/document_qa_engine.py @@ -41,10 +41,6 @@ def __init__(self, ): self.embedding_function = embedding_function self.llm = llm - # if memory: - # prompt = self.default_prompts[qa_chain_type].PROMPT_SELECTOR.get_prompt(llm) - # self.chain = load_qa_chain(llm, chain_type=qa_chain_type, prompt=prompt, memory=memory) - # else: self.memory = memory self.chain = load_qa_chain(llm, chain_type=qa_chain_type) @@ -161,7 +157,7 @@ def _parse_json(self, response, output_parser): def _run_query(self, doc_id, query, context_size=4): relevant_documents = self._get_context(doc_id, query, context_size) response = self.chain.run(input_documents=relevant_documents, - question=query) + question=query) if self.memory: self.memory.save_context({"input": query}, {"output": response}) @@ -172,7 +168,9 @@ def _get_context(self, doc_id, query, context_size=4): retriever = db.as_retriever(search_kwargs={"k": context_size}) relevant_documents = retriever.get_relevant_documents(query) if self.memory and len(self.memory.buffer_as_messages) > 0: - relevant_documents.append(Document(page_content="Previous conversation:\n{}\n\n".format(self.memory.buffer_as_str))) + relevant_documents.append( + Document(page_content="Previous conversation:\n{}\n\n".format(self.memory.buffer_as_str)) + ) return relevant_documents def get_all_context_by_document(self, doc_id): diff --git a/streamlit_app.py b/streamlit_app.py index 8a4118e..affa158 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -5,7 +5,6 @@ import dotenv from grobid_quantities.quantities import QuantitiesAPI -from langchain.callbacks import PromptLayerCallbackHandler from langchain.llms.huggingface_hub import HuggingFaceHub from langchain.memory import ConversationBufferWindowMemory