Skip to content

Commit

Permalink
Merge pull request #21 from flepied/sources
Browse files Browse the repository at this point in the history
Display the document sources at the end of the answer
  • Loading branch information
flepied authored Aug 25, 2023
2 parents 9068272 + 9152e08 commit ef3e625
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 38 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ jobs:
curl -sSL https://install.python-poetry.org | python3 -
poetry install --with test
poetry run pip3 install torch
poetry run pre-commit run -a --show-diff-on-failure
poetry run pre-commit run -a --show-diff-on-failure -v
- name: Run integration tests
env:
HUGGINGFACEHUB_API_TOKEN: ${{ secrets.HUGGINGFACEHUB_API_TOKEN }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
set -ex
./integration-test.sh
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ repos:
pass_filenames: false
files: ^(.*/)?pyproject.toml$

- id: poetry-lock
name: Poetry lock
description: run poetry lock to check consistency
entry: poetry lock --check
- id: poetry-check-lock
name: Poetry check lock
description: run poetry check for lock consistency
entry: poetry check --lock
language: python
pass_filenames: false
files: ^(.*/)?pyproject.toml$
Expand Down
10 changes: 10 additions & 0 deletions integration-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,18 @@ done
sudo journalctl -u sba-md
sudo journalctl -u sba-txt

# Turn command tracing off so the captured answers below are readable in CI logs.
set +x

# test the vector store
# similarity.py queries the vector store directly; any non-empty result passes.
RES=$(poetry run ./similarity.py "What is langchain?")
echo "$RES"
test -n "$RES"

# test the vector store and llm
# qa.py goes through the full LLM chain; the run fails when the model answers
# with its "I don't know." fallback, meaning no relevant document was found.
RES=$(poetry run ./qa.py "What is langchain?")
echo "$RES"
if grep -q "I don't know." <<< "$RES"; then
    exit 1
fi

# integration-test.sh ends here
70 changes: 68 additions & 2 deletions lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import time

import chromadb
from langchain import OpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain.vectorstores import Chroma


def cleanup_text(text):
"""Clean up tetextt.
"""Clean up text
- remove urls
- remove hashtag
Expand Down Expand Up @@ -68,7 +70,7 @@ def get_indexer():


def is_same_time(fname, oname):
"compare if {fname} and {oname} have the same timestamp"
"Compare if {fname} and {oname} have the same timestamp"
ftime = os.stat(fname).st_mtime
# do not write if the timestamps are the same
try:
Expand All @@ -80,4 +82,68 @@ def is_same_time(fname, oname):
return False


def local_link(path):
    "Create a local link to a file"
    # Absolute filesystem paths are turned into file: URLs so they are
    # clickable; anything else (already a URL or relative) passes through.
    return f"file:{path}" if path.startswith("/") else path


class Agent:
    """Agent that answers questions against the vector store with an LLM.

    Wraps a RetrievalQAWithSourcesChain over the vector store returned by
    get_vectorstore() and formats answers together with their document
    sources, either as plain text or as HTML links.
    """

    def __init__(self):
        "Initialize the agent: build the retrieval chain over the vector store"
        self.vectorstore = get_vectorstore()
        self.chain = RetrievalQAWithSourcesChain.from_llm(
            llm=OpenAI(temperature=0),
            retriever=self.vectorstore.as_retriever(),
        )

    def question(self, user_question):
        "Ask a question and format the answer for text"
        # Sources are rendered as-is for plain-text output.
        return self._answer(user_question, lambda src: src)

    def html_question(self, user_question):
        "Ask a question and format the answer for html"
        # Each source is rendered as a clickable anchor.
        return self._answer(
            user_question, lambda src: f'<a href="{local_link(src)}">{src}</a>'
        )

    def _answer(self, user_question, render_source):
        """Get a response and append the rendered sources when present.

        render_source maps one resolved source (path or url) to its display
        form. The chain reports the literal string "None." in the sources
        field when it found no supporting document.
        """
        response = self._get_response(user_question)
        if response["sources"] != "None.":
            sources = "- " + "\n- ".join(
                render_source(src)
                for src in self._get_real_sources(response["sources"])
            )
            return f"{response['answer']}\nSources:\n{sources}"
        return response["answer"]

    def _get_response(self, user_question):
        "Get the response from the LLM and vector store"
        return self.chain({"question": user_question})

    def _get_real_sources(self, sources):
        """Get the url instead of the chunk sources.

        sources is the comma-separated string returned by the chain. Each
        chunk source is looked up in the vector store metadata; when a "url"
        entry exists it replaces the chunk path (de-duplicated), otherwise
        the raw source is kept.
        """
        real_sources = []
        for source in sources.split(", "):
            results = self.vectorstore.get(
                include=["metadatas"], where={"source": source}
            )
            if (
                results
                and "metadatas" in results
                and len(results["metadatas"]) > 0
                and "url" in results["metadatas"][0]
            ):
                url = results["metadatas"][0]["url"]
                if url not in real_sources:
                    real_sources.append(url)
            else:
                real_sources.append(source)
        return real_sources


# lib.py ends here
15 changes: 15 additions & 0 deletions qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env python3

"CLI to ask questions to the agent"

import sys

from dotenv import load_dotenv

import lib

# Load API keys and settings from .env before building the agent.
load_dotenv()
user_question = " ".join(sys.argv[1:])
answer = lib.Agent().question(user_question)
print(answer)

# qa.py ends here
44 changes: 13 additions & 31 deletions second_brain_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,18 @@

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory

from htmlTemplates import bot_template, css, user_template
from lib import get_vectorstore
from htmlTemplates import css, user_template
from lib import Agent


def handle_userinput(user_question):
"Handle the input from the user as a question to the LLM"
response = st.session_state.conversation({"question": user_question})
st.session_state.chat_history = response["chat_history"]

for i, message in enumerate(st.session_state.chat_history):
if i % 2 == 0:
st.write(
user_template.replace("{{MSG}}", message.content),
unsafe_allow_html=True,
)
else:
st.write(
bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True
)
response = st.session_state.agent.html_question(user_question)
st.write(
user_template.replace("{{MSG}}", response),
unsafe_allow_html=True,
)


def clear_input_box():
Expand All @@ -40,19 +29,14 @@ def clear_input_box():
def main():
"Entry point"
load_dotenv()
st.set_page_config(page_title="Chat with your Second Brain", page_icon=":brain:")
st.set_page_config(
page_title="Ask questions to your Second Brain", page_icon=":brain:"
)
st.write(css, unsafe_allow_html=True)
st.header("Chat with your Second Brain :brain:")
st.header("Ask a question to your Second Brain :brain:")

if "conversation" not in st.session_state:
memory = ConversationBufferMemory(
memory_key="chat_history", return_messages=True
)
st.session_state.conversation = ConversationalRetrievalChain.from_llm(
ChatOpenAI(temperature=0),
get_vectorstore().as_retriever(),
memory=memory,
)
if "agent" not in st.session_state:
st.session_state.agent = Agent()

st.text_input(
"Ask a question to your second brain:",
Expand All @@ -72,8 +56,6 @@ def main():
""",
height=150,
)
if "chat_history" not in st.session_state:
st.session_state.chat_history = None


if __name__ == "__main__":
Expand Down

0 comments on commit ef3e625

Please sign in to comment.