api.py (forked from SaranDharshanSP/AmritaGPT)
import asyncio
import uuid
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain.vectorstores import FAISS
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
import os
from dotenv import load_dotenv
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Dict, Any
from huggingface_hub import InferenceClient
load_dotenv()
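# Environment variables (loaded from a local .env file): HF_API_TOKEN is read below for
# the Hugging Face Inference API. The Google path (GoogleGenerativeAIEmbeddings /
# ChatGoogleGenerativeAI) typically also expects GOOGLE_API_KEY to be set; that is an
# assumption based on langchain_google_genai defaults, not something this file configures.
# Example .env (placeholder values):
#   HF_API_TOKEN=hf_...
#   GOOGLE_API_KEY=...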
# Load and split the text file
with open("general.txt", encoding="utf8") as f:
    raw_text = f.read()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    length_function=len,
)
text_chunks = text_splitter.split_text(raw_text)
# Load embeddings for both methods
huggingface_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
google_embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
vector_store_hf = FAISS.from_texts(text_chunks, embedding=huggingface_embeddings)
vector_store_google = FAISS.from_texts(text_chunks, embedding=google_embeddings)
vector_store_hf.save_local("faiss_index_hf")
vector_store_google.save_local("faiss_index_google")
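# Note: both FAISS indexes are rebuilt from general.txt on every startup and then saved
# to disk. A minimal sketch of skipping the rebuild when a saved index already exists
# (an illustration, not part of the original flow):
#
#   if os.path.exists("faiss_index_hf"):
#       vector_store_hf = FAISS.load_local(
#           "faiss_index_hf", huggingface_embeddings, allow_dangerous_deserialization=True
#       )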
client = InferenceClient(api_key=os.getenv("HF_API_TOKEN"))
# Build the conversational QA chain (used only for the Google/Gemini path;
# the Hugging Face path calls the Inference API directly in answer_question)
def get_conversational_chain(use_google: bool):
    prompt_template = """
    You are an assistant for Amrita University. Use the following pieces of information to help answer the user's questions.
    Always maintain context from the previous conversation and combine it with new information from the knowledge base.

    Previous Conversation:
    {chat_history}

    Context from knowledge base:
    {context}

    Current Question: {question}

    If you previously provided information that's relevant to the current question, use that information along with any new context.
    If you cannot find specific information in the current context but you mentioned it in the chat history, you can refer to that.
    If you truly don't have enough information to answer, acknowledge what you know and what you don't know.
    """
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question", "chat_history"]
    )
    if use_google:
        model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
        chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
        return chain
# Format chat history
def format_chat_history(history):
    formatted_history = ""
    for entry in history:
        if "user" in entry:
            formatted_history += f"Human: {entry['user']}\n"
        if "bot" in entry:
            formatted_history += f"Assistant: {entry['bot']}\n"
    return formatted_history
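# Example (illustration): format_chat_history([{"user": "Hi"}, {"bot": "Hello!"}])
# returns "Human: Hi\nAssistant: Hello!\n".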
# Answer question
def answer_question(user_question, chat_history, use_google=False):
    # Load the appropriate FAISS vector store
    vector_store = FAISS.load_local(
        "faiss_index_google" if use_google else "faiss_index_hf",
        google_embeddings if use_google else huggingface_embeddings,
        allow_dangerous_deserialization=True
    )

    # Create a combined query using recent context
    context_query = user_question
    if chat_history:
        last_exchange = chat_history[-2:]  # Get last user question and bot response
        context_query = f"{' '.join([msg.get('user', msg.get('bot', '')) for msg in last_exchange])} {user_question}"

    # Retrieve relevant documents
    docs = vector_store.similarity_search(context_query)

    # Extract the text content from the retrieved documents
    context = "\n".join([doc.page_content for doc in docs])

    # Format conversation history for context
    formatted_history = format_chat_history(chat_history[-4:])  # Last 4 exchanges

    if use_google:
        chain = get_conversational_chain(use_google=use_google)
        response = chain(
            {
                "input_documents": docs,
                "question": user_question,
                "chat_history": formatted_history
            },
            return_only_outputs=True
        )
        return response["output_text"]
    else:
        print(formatted_history)
        # Construct messages for the Hugging Face model
        messages = [
            {"role": "system", "content": "Use the following context to answer the user's question."},
            {"role": "system", "content": context},
            {"role": "user", "content": user_question},
            {"role": "system", "content": formatted_history},
        ]
        completion = client.chat.completions.create(
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            messages=messages,
            max_tokens=500,
        )
        model_response = completion.choices[0].message["content"]
        return model_response
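# In-memory, per-process session store: maps a session_id to that session's chat history.
# Histories are lost whenever the server restarts.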
conversation_context: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
class QueryRequest(BaseModel):
    session_id: str | None = None
    input_text: str
    use_google: bool = False
app = FastAPI()
@app.get("/")
def index():
    return {"message": "Hello! Use the /get-response endpoint to chat."}
@app.post("/get-response/")
def get_response(request: QueryRequest):
    session_id = request.session_id if request.session_id else str(uuid.uuid4())
    user_input = request.input_text

    if session_id not in conversation_context:
        conversation_context[session_id] = {"history": []}
    conversation_history = conversation_context[session_id]["history"]

    conversation_history.append({"user": user_input})
    response = answer_question(user_input, conversation_history, use_google=request.use_google)
    conversation_history.append({"bot": response})
    conversation_context[session_id]["history"] = conversation_history

    return {"session_id": session_id, "response": response, "history": conversation_history}
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
if __name__ == "__main__":
    config = uvicorn.Config(app, host="127.0.0.1", port=8000, log_level="info")
    server = uvicorn.Server(config)
    asyncio.run(server.serve())
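# Example client call against a locally running server (an illustration; the actual
# response text depends on general.txt and the selected model):
#
#   import requests
#   r = requests.post(
#       "http://127.0.0.1:8000/get-response/",
#       json={"input_text": "Tell me about Amrita University", "use_google": False},
#   )
#   data = r.json()  # contains "session_id", "response", and "history"
#
# Send the returned session_id back in later requests to keep that session's context.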