-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
116 lines (96 loc) · 4.69 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pymongo
import os, textwrap
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
# from langchain_community.document_loaders import PyPDFLoader
# from langchain_community.document_loaders import DirectoryLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from flask import Flask, request, render_template
from flask_cors import CORS
from gevent import pywsgi
# Flask application; CORS is enabled on all routes with flask_cors defaults.
app = Flask(__name__)
CORS(app)

# Configuration is read from the environment. Either value is None when the
# variable is unset; nothing here validates them.
mongodb_connection_string = os.getenv("MONGODB_CONNECTION_STRING")
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")

# Models are created once at import time and shared by every request.
# Embeds user queries for the Atlas vector search. Presumably this must match
# the model used to embed the stored documents -- TODO confirm.
instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
# Hosted Mistral instruct model called through the Hugging Face Hub. The low
# temperature keeps answers close to deterministic.
llm = HuggingFaceHub(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    # Pass the token read above explicitly. When it is None, LangChain falls
    # back to the same HUGGINGFACEHUB_API_TOKEN environment variable.
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    model_kwargs={"temperature": 0.1, "max_length": 512},
)
# The collection database.textbooks is assumed to have been embedded and
# indexed offline (Atlas vector index named "search"); this module only
# queries it.

# Prompt for the "stuff" chain: retrieved chunks are substituted for {context}.
_PROMPT_TEMPLATE = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
"""

# Lazily built RetrievalQA chain, shared by all requests (see _get_qa_chain).
_qa_chain = None


def _get_qa_chain():
    """Return the shared RetrievalQA chain, building it on first use.

    Building the vector store opens a MongoDB client, so the vector store,
    retriever, prompt and chain are constructed once and reused instead of
    per request. If construction raises, nothing is cached and the next
    request retries.
    """
    global _qa_chain
    if _qa_chain is None:
        vector_search = MongoDBAtlasVectorSearch.from_connection_string(
            mongodb_connection_string,
            "database" + "." + "textbooks",
            instructor_embeddings,
            index_name="search",
        )
        # Plain similarity search returning the top 3 chunks.
        retriever = vector_search.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 3},
        )
        prompt = PromptTemplate(
            template=_PROMPT_TEMPLATE, input_variables=["context", "question"]
        )
        _qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )
    return _qa_chain


def _format_source(metadata):
    """Convert a retrieved document's metadata into a ``[name, page]`` pair.

    Assumes 'source' looks like 'Documents/<name>.pdf': the slice drops the
    10-character 'Documents/' prefix and the 4-character '.pdf' suffix. Also
    assumes 'page' is a 0-based page index, reported here 1-based. TODO:
    confirm both against the ingestion job. A missing 'source' yields '' and
    a missing 'page' yields None, instead of raising TypeError.
    """
    source = metadata.get('source') or ''
    page = metadata.get('page')
    return [source[10:-4], page + 1 if page is not None else None]


@app.route('/', methods=['GET', 'POST'])
def main():
    """Serve the UI on GET; answer a question over the textbook index on POST.

    On POST the question is read from the ``q`` query-string parameter,
    falling back to a ``q`` form field.

    Returns:
        * ``index.html`` rendered, for GET.
        * ``({"error": ...}, 400)`` when no non-blank question was supplied.
        * ``{"result": answer, "source": [[name, page], ...]}`` when the model
          output contains "Answer:". ``answer`` is the stripped text after the
          marker, and ``source`` is emptied when the answer contains
          "I don't know".
        * The plain string ``"Error"`` when the marker is absent.
    """
    if request.method != 'POST':
        return render_template('index.html')

    query = (request.args.get('q') or request.form.get('q') or '').strip()
    if not query:
        return {"error": "missing query parameter 'q'"}, 400
    print("==================== query is -", query, "====================")

    llm_response = _get_qa_chain().invoke({"query": query})
    res = llm_response['result']
    print(res)

    marker = "Answer:"
    index_helpful_answer = res.find(marker)
    if index_helpful_answer == -1:
        # The model did not produce the expected "Answer:" section.
        return "Error"

    helpful_answer_text = res[index_helpful_answer + len(marker):].strip()
    source = [_format_source(item.metadata) for item in llm_response['source_documents']]
    return {
        "result": helpful_answer_text,
        # Citing sources for a non-answer would be misleading.
        "source": source if "I don't know" not in helpful_answer_text else [],
    }
if __name__ == '__main__':
    # Serve the app with gevent's WSGI server on every interface, port 7860.
    # For local debugging, Flask's dev server is an alternative:
    #   app.run(host="0.0.0.0", port=7860)
    listen_address = ('0.0.0.0', 7860)
    pywsgi.WSGIServer(listen_address, app).serve_forever()