myGPT.py
import os
from dotenv import load_dotenv
load_dotenv()
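# load_dotenv() reads configuration from a local .env file. A minimal sketch of
# what that file could contain (the values below are illustrative assumptions,
# not project defaults):
#   EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2
#   PERSIST_DIRECTORY=db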
import openai
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import (StuffDocumentsChain,
LLMChain)
from langchain.schema import HumanMessage, AIMessage
from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain.callbacks.manager import trace_as_chain_group
import gradio as gr
from constants import CHROMA_SETTINGS
"""
import LANGCHAIN_API_KEY in case you encounter the error:
langsmith.utils.LangSmithUserError: API key must be provided when using hosted LangSmith API
Create here: https://smith.langchain.com/
"""
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = "YOUR_LANGCHAIN_API_KEY"
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
# Create embeddings
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
# Create the vectorstore, persisted locally
db = Chroma(embedding_function=embeddings,
persist_directory=persist_directory,
client_settings=CHROMA_SETTINGS)
# Set up our retriever
retriever = db.as_retriever()
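# A sketch of how the retriever could be tuned instead of using the defaults;
# the search type and k value here are illustrative assumptions, not project
# settings:
# retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 6})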
# Define the LLM; replace the placeholder key, or drop the argument so that
# ChatOpenAI reads the OPENAI_API_KEY environment variable instead.
llm = ChatOpenAI(temperature=0, openai_api_key="YOUR_OPENAI_API_KEY")
"""
Set up the chain that answers questions based on documents.
`document_prompt` controls how each retrieved document is formatted;
it is passed to `format_document` - see that function for more details.
"""
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
# The prompt must take the `document_variable_name` ("context") as an input variable.
prompt_template = """Use the following pieces of context to answer user questions.
If you don't know the answer, just say that the answer cannot be found in the
knowledge base; don't try to make up an answer.
--------------
{context}"""
system_prompt = SystemMessagePromptTemplate.from_template(prompt_template)
prompt = ChatPromptTemplate(
messages=[
system_prompt,
MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template("{question}")
]
)
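# The resulting prompt is assembled in this order: the system message with the
# retrieved {context}, the prior conversation, then the new {question}.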
llm_chain = LLMChain(llm=llm, prompt=prompt)
combine_docs_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name,
document_separator="---------"
)
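# With these settings, StuffDocumentsChain formats each retrieved document with
# `document_prompt`, joins them with "---------", and substitutes the result
# for {context} in the system prompt before calling the LLM.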
"""
Set up a chain that controls how the search query for the vectorstore is generated:
This controls how the search query is generated.
Should take `chat_history` and `question` as input variables.
"""
template = """Combine the chat history and follow up question into a a search query.
Chat History:
{chat_history}
Follow up question: {question}
"""
prompt = PromptTemplate.from_template(template)
question_generator_chain = LLMChain(llm=llm, prompt=prompt)
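# Condensing the history plus the follow-up into a standalone query lets the
# vectorstore search succeed even when the new message only makes sense in
# context (e.g. "what about the second one?").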
# Chat handler used by the Gradio ChatInterface
def qa_response(message, history):
# Convert message history into format for the `question_generator_chain`.
convo_string = "\n\n".join([f"Human: {h}\nAssistant: {a}" for h, a in history])
# Convert message history into LangChain format for the final response chain.
messages = []
for human, ai in history:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
# Wrap all actual calls to chains in a trace group.
with trace_as_chain_group("qa_response") as group_manager:
# Generate search query.
search_query = question_generator_chain.run(
question=message,
chat_history=convo_string,
callbacks=group_manager
)
# Retrieve relevant docs.
docs = retriever.get_relevant_documents(search_query, callbacks=group_manager)
# Answer question.
return combine_docs_chain.run(
input_documents=docs,
chat_history=messages,
question=message,
callbacks=group_manager
)
# Start the Gradio chat app
gr.ChatInterface(qa_response).launch(share=True, debug=True)
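# share=True makes Gradio create a temporary public URL in addition to the
# local server; set share=False for local-only use.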