From 9152e08e25920389dc0b5af7abd9ccbac9989620 Mon Sep 17 00:00:00 2001
From: Frederic Lepied
Date: Thu, 24 Aug 2023 09:44:34 +0200
Subject: [PATCH] Display the document sources at the end of the answer

- add a new qa.py script to ask a question on the command line
---
 .github/workflows/pr.yml |  5 ++-
 .pre-commit-config.yaml  |  8 ++---
 integration-test.sh      | 10 ++++++
 lib.py                   | 70 ++++++++++++++++++++++++++++++++++++++--
 qa.py                    | 15 +++++++++
 second_brain_agent.py    | 44 ++++++++-----------------
 6 files changed, 114 insertions(+), 38 deletions(-)
 create mode 100755 qa.py

diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml
index 06ea9ec..8c04042 100644
--- a/.github/workflows/pr.yml
+++ b/.github/workflows/pr.yml
@@ -24,9 +24,12 @@ jobs:
         curl -sSL https://install.python-poetry.org | python3 -
         poetry install --with test
         poetry run pip3 install torch
-        poetry run pre-commit run -a --show-diff-on-failure
+        poetry run pre-commit run -a --show-diff-on-failure -v
 
     - name: Run integration tests
+      env:
+        HUGGINGFACEHUB_API_TOKEN: ${{ secrets.HUGGINGFACEHUB_API_TOKEN }}
+        OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
       run: |
         set -ex
         ./integration-test.sh
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 30f808c..f84257e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -75,10 +75,10 @@ repos:
         pass_filenames: false
         files: ^(.*/)?pyproject.toml$
 
-      - id: poetry-lock
-        name: Poetry lock
-        description: run poetry lock to check consistency
-        entry: poetry lock --check
+      - id: poetry-check-lock
+        name: Poetry check lock
+        description: run poetry check for lock consistency
+        entry: poetry check --lock
         language: python
         pass_filenames: false
         files: ^(.*/)?pyproject.toml$
diff --git a/integration-test.sh b/integration-test.sh
index 360e0d6..3d7c183 100755
--- a/integration-test.sh
+++ b/integration-test.sh
@@ -38,8 +38,18 @@ done
 sudo journalctl -u sba-md
 sudo journalctl -u sba-txt
 
+set +x
+
+# test the vector store
 RES=$(poetry run ./similarity.py "What is langchain?")
 echo "$RES"
 test -n "$RES"
 
+# test the vector store and llm
+RES=$(poetry run ./qa.py "What is langchain?")
+echo "$RES"
+if grep -q "I don't know." <<< "$RES"; then
+    exit 1
+fi
+
 # integration-test.sh ends here
diff --git a/lib.py b/lib.py
index e017cf5..0219514 100644
--- a/lib.py
+++ b/lib.py
@@ -6,13 +6,15 @@
 import time
 
 import chromadb
+from langchain import OpenAI
+from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.indexes.vectorstore import VectorStoreIndexWrapper
 from langchain.vectorstores import Chroma
 
 
 def cleanup_text(text):
-    """Clean up tetextt.
+ """Clean up text - remove urls - remove hashtag @@ -68,7 +70,7 @@ def get_indexer(): def is_same_time(fname, oname): - "compare if {fname} and {oname} have the same timestamp" + "Compare if {fname} and {oname} have the same timestamp" ftime = os.stat(fname).st_mtime # do not write if the timestamps are the same try: @@ -80,4 +82,68 @@ def is_same_time(fname, oname): return False +def local_link(path): + "Create a local link to a file" + if path.startswith("/"): + return f"file:{path}" + return path + + +class Agent: + "Agent to answer questions" + + def __init__(self): + "Initialize the agent" + self.vectorstore = get_vectorstore() + self.chain = RetrievalQAWithSourcesChain.from_llm( + llm=OpenAI(temperature=0), + retriever=self.vectorstore.as_retriever(), + ) + + def question(self, user_question): + "Ask a question and format the answer for text" + response = self._get_response(user_question) + if response["sources"] != "None.": + sources = "- " + "\n- ".join(self._get_real_sources(response["sources"])) + return f"{response['answer']}\nSources:\n{sources}" + return response["answer"] + + def html_question(self, user_question): + "Ask a question and format the answer for html" + response = self._get_response(user_question) + if response["sources"] != "None.": + sources = "- " + "\n- ".join( + [ + f'{src}' + for src in self._get_real_sources(response["sources"]) + ] + ) + return f"{response['answer']}\nSources:\n{sources}" + return response["answer"] + + def _get_response(self, user_question): + "Get the response from the LLM and vector store" + return self.chain({"question": user_question}) + + def _get_real_sources(self, sources): + "Get the url instead of the chunk sources" + real_sources = [] + for source in sources.split(", "): + results = self.vectorstore.get( + include=["metadatas"], where={"source": source} + ) + if ( + results + and "metadatas" in results + and len(results["metadatas"]) > 0 + and "url" in results["metadatas"][0] + ): + url = results["metadatas"][0]["url"] + if url not in real_sources: + real_sources.append(url) + else: + real_sources.append(source) + return real_sources + + # lib.py ends here diff --git a/qa.py b/qa.py new file mode 100755 index 0000000..bb0361f --- /dev/null +++ b/qa.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +"CLI to ask questions to the agent" + +import sys + +from dotenv import load_dotenv + +import lib + +load_dotenv() +agent = lib.Agent() +print(agent.question(" ".join(sys.argv[1:]))) + +# qa.py ends here diff --git a/second_brain_agent.py b/second_brain_agent.py index b115aba..399902f 100644 --- a/second_brain_agent.py +++ b/second_brain_agent.py @@ -6,29 +6,18 @@ import streamlit as st from dotenv import load_dotenv -from langchain.chains import ConversationalRetrievalChain -from langchain.chat_models import ChatOpenAI -from langchain.memory import ConversationBufferMemory -from htmlTemplates import bot_template, css, user_template -from lib import get_vectorstore +from htmlTemplates import css, user_template +from lib import Agent def handle_userinput(user_question): "Handle the input from the user as a question to the LLM" - response = st.session_state.conversation({"question": user_question}) - st.session_state.chat_history = response["chat_history"] - - for i, message in enumerate(st.session_state.chat_history): - if i % 2 == 0: - st.write( - user_template.replace("{{MSG}}", message.content), - unsafe_allow_html=True, - ) - else: - st.write( - bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True - ) + response = 
+    st.write(
+        user_template.replace("{{MSG}}", response),
+        unsafe_allow_html=True,
+    )
 
 
 def clear_input_box():
@@ -40,19 +29,14 @@ def main():
     "Entry point"
     load_dotenv()
-    st.set_page_config(page_title="Chat with your Second Brain", page_icon=":brain:")
+    st.set_page_config(
+        page_title="Ask questions to your Second Brain", page_icon=":brain:"
+    )
     st.write(css, unsafe_allow_html=True)
-    st.header("Chat with your Second Brain :brain:")
+    st.header("Ask a question to your Second Brain :brain:")
 
-    if "conversation" not in st.session_state:
-        memory = ConversationBufferMemory(
-            memory_key="chat_history", return_messages=True
-        )
-        st.session_state.conversation = ConversationalRetrievalChain.from_llm(
-            ChatOpenAI(temperature=0),
-            get_vectorstore().as_retriever(),
-            memory=memory,
-        )
+    if "agent" not in st.session_state:
+        st.session_state.agent = Agent()
 
     st.text_input(
         "Ask a question to your second brain:",
@@ -72,8 +56,6 @@ def main():
     """,
         height=150,
     )
-    if "chat_history" not in st.session_state:
-        st.session_state.chat_history = None
 
 
 if __name__ == "__main__":
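
Reviewer note: the snippet below is a usage sketch, not part of the patch. It mirrors the added qa.py to show how the new lib.Agent class is driven end to end, and it assumes what integration-test.sh assumes: the Chroma vector store has already been populated and OPENAI_API_KEY is reachable through .env.

#!/usr/bin/env python3
# Usage sketch only (not part of the patch) -- mirrors the new qa.py.
# Assumes a populated Chroma store and OPENAI_API_KEY provided via .env.
import sys

from dotenv import load_dotenv

import lib

load_dotenv()  # same bootstrap as qa.py
agent = lib.Agent()

question = " ".join(sys.argv[1:]) or "What is langchain?"

# Plain-text answer; Agent.question() appends a "Sources:" bullet list
# whenever the chain reports sources other than "None."
print(agent.question(question))

# HTML variant rendered by the Streamlit app: each source becomes an
# <a href="..."> link, with absolute paths turned into file: URLs by local_link
print(agent.html_question(question))

On a healthy setup the answer should be something other than "I don't know.", which is exactly what the new qa.py check in integration-test.sh asserts.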