-
Notifications
You must be signed in to change notification settings - Fork 1
/
app.py
69 lines (59 loc) · 1.86 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import re

import chainlit as cl
from langchain import hub
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import Ollama
from langchain.vectorstores import Chroma
# Prompt template for the RetrievalQA chain. It is pulled from the LangChain
# Hub once, when this module is imported, so network access is needed at
# startup. The "rag-prompt-mistral" variant is presumably tailored to Mistral
# models, matching the "mistral" model requested from Ollama in load_llm.
_RAG_PROMPT_REF = "rlm/rag-prompt-mistral"
QA_CHAIN_PROMPT = hub.pull(_RAG_PROMPT_REF)
# Load the LLM.
def load_llm(model="mistral"):
    """Create the Ollama-backed LLM used by the QA chain.

    Args:
        model: Name of the Ollama model to run. Defaults to ``"mistral"``,
            the model this app was written for.

    Returns:
        A LangChain ``Ollama`` LLM with verbose logging enabled whose
        generated tokens are also streamed to the server's stdout.
    """
    return Ollama(
        model=model,
        verbose=True,
        # ``callbacks`` supersedes the deprecated ``callback_manager`` argument.
        callbacks=[StreamingStdOutCallbackHandler()],
    )
def retrieval_qa_chain(llm, vectorstore):
    """Wrap *llm* and *vectorstore* in a RetrievalQA chain.

    The chain formats its prompt with the module-level ``QA_CHAIN_PROMPT``.
    It is built with ``return_source_documents=True``, so its result carries
    the retrieved documents next to the answer.
    """
    retriever = vectorstore.as_retriever()
    stuff_options = {"prompt": QA_CHAIN_PROMPT}
    return RetrievalQA.from_chain_type(
        llm,
        retriever=retriever,
        chain_type_kwargs=stuff_options,
        return_source_documents=True,
    )
def qa_bot():
    """Assemble the RetrievalQA chain used by a chat session.

    Loads the LLM via ``load_llm``. Then it opens the Chroma vector store
    persisted under the relative path ``vdb/``, which resolves against the
    current working directory, with GPT4All embeddings as the query embedder.
    Both are wired together with ``retrieval_qa_chain``.
    """
    language_model = load_llm()
    persist_dir = "vdb/"
    vector_db = Chroma(
        persist_directory=persist_dir,
        embedding_function=GPT4AllEmbeddings(),
    )
    return retrieval_qa_chain(language_model, vector_db)
@cl.on_chat_start
async def start():
    """Build the QA chain for a new chat session and greet the user.

    The "Firing up" placeholder is sent *before* the chain is built so the
    user gets feedback while the LLM and vector store are set up. The
    synchronous ``qa_bot`` runs in a worker thread via ``cl.make_async`` so
    it does not block Chainlit's event loop. The chain is stored in the user
    session before the greeting invites the first query.
    """
    msg = cl.Message(content="Firing up the research info bot...")
    await msg.send()

    chain = await cl.make_async(qa_bot)()
    cl.user_session.set("chain", chain)

    msg.content = "Hi, welcome to running injury info bot. What is your query?"
    await msg.update()
def _format_answer(answer, sources):
    """Return the user-facing reply: *answer* followed by a sources footer.

    A period followed by spaces or tabs is turned into a period plus a line
    break, for readability. Unlike replacing every ``.``, this leaves
    decimals ("2.5 km"), URLs and file names intact. Existing line breaks in
    *answer* are preserved.

    Args:
        answer: Answer text produced by the chain.
        sources: Retrieved source documents; empty or ``None`` means none.

    Returns:
        The formatted reply string.
    """
    text = re.sub(r"\.[ \t]+", ".\n", answer)
    if sources:
        text += "\nSources: " + str(sources)
    else:
        text += "\nNo Sources found"
    return text


@cl.on_message
async def main(message):
    """Answer a user message with the session's RetrievalQA chain.

    Runs the chain on ``message.content`` while streaming tokens to the UI.
    The final reply is the answer plus its sources. If the answer was
    already streamed into a message, that message is rewritten in place
    rather than the answer being sent a second time.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        # on_chat_start never stored a chain (e.g. it raised). Without this
        # guard, the acall below would fail with AttributeError on None.
        await cl.Message(
            content="The bot is not ready yet. Please start a new chat."
        ).send()
        return

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True,
        answer_prefix_tokens=["FINAL", "ANSWER"],
    )
    # Mark the answer as already reached so tokens are streamed right away
    # instead of waiting for the "FINAL ANSWER" prefix. Presumably the RAG
    # prompt never emits that prefix; confirm if the prompt changes.
    cb.answer_reached = True

    res = await chain.acall(message.content, callbacks=[cb])
    print(f"response: {res}")

    reply = _format_answer(res["result"], res.get("source_documents"))

    final_stream = getattr(cb, "final_stream", None)
    if getattr(cb, "has_streamed_final_answer", False) and final_stream is not None:
        # The answer was already streamed into its own message. Update that
        # message with the formatted reply instead of posting a duplicate.
        final_stream.content = reply
        await final_stream.update()
    else:
        await cl.Message(content=reply).send()