-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
58 lines (46 loc) · 1.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
import chainlit as cl
from typing import Union
import os
def get_vector_store(db_path: str,
                     api_key: str) -> FAISS:
    """Load a FAISS vector store that was previously saved under *db_path*.

    Args:
        db_path: Directory passed to ``FAISS.load_local``.
        api_key: OpenAI API key used to build the ``OpenAIEmbeddings`` instance
            that the loaded store uses to embed queries.

    Returns:
        The loaded ``FAISS`` vector store.
    """
    # NOTE(review): allow_dangerous_deserialization=True tells LangChain to
    # trust the on-disk files when deserializing them; only point db_path at an
    # index you built yourself.
    return FAISS.load_local(
        db_path,
        OpenAIEmbeddings(openai_api_key=api_key),
        allow_dangerous_deserialization=True,
    )
# Directory the FAISS vector store is loaded from (relative to the working directory).
db_path = "vector_store/"

# OpenAI API key, shared by the embeddings model and the chat model below.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    # Fail fast at startup with an actionable message instead of handing
    # None / "" to the OpenAI clients and failing later with a less obvious error.
    raise RuntimeError(
        "The OPENAI_API_KEY environment variable is not set; "
        "export it before starting the app."
    )

# Load the vector store.
db = get_vector_store(db_path, api_key)

# Chat model that writes the final answer.
chat_model = ChatOpenAI(openai_api_key=api_key)

# Retriever over the vector store.
retriever = db.as_retriever()

# Retrieval-QA chain: retrieves documents for a query and answers with the chat model.
# The result dict also carries "source_documents" because of return_source_documents=True.
chatbot = RetrievalQA.from_llm(
    llm=chat_model,
    retriever=retriever,
    return_source_documents=True,
)
@cl.on_chat_start
async def on_chat_start():
    """Send a greeting message when a new chat session starts."""
    # Greeting text (Korean): "Ready. Please enter a message."
    greeting = cl.Message(content="준비되었습니다. 메시지를 입력하세요.")
    await greeting.send()
@cl.on_message
async def on_message(prompt: Union[str, cl.Message]):
    """Answer an incoming chat message with the retrieval-QA chain.

    Args:
        prompt: The user's message, either as a ``cl.Message`` (whose
            ``content`` is used) or as a plain string.
    """
    if isinstance(prompt, cl.Message):
        prompt = prompt.content
    # Log the incoming message (label is Korean for "Input message").
    print(f"입력된 메시지: {prompt}")
    # Await the chain's async entry point so retrieval and the LLM request do
    # not block the event loop (and every other session) while they run.
    # RetrievalQA takes its input under the "query" key.
    result = await chatbot.ainvoke({"query": prompt})
    answer = result["result"]
    await cl.Message(content=answer).send()
# # Chat with the chatbot
# test_prompt = "유퀴즈 출연 후 아쉬웠던 점?"
# result = chatbot(test_prompt)
# print(result["query"])
# print(result["result"])
# print(result['source_documents'])