-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
136 lines (101 loc) · 5.24 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# pip install streamlit langchain langchain-openai beautifulsoup4 python-dotenv chromadb
import logging
import os

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
# Pull variables from a local .env file into the process environment.
load_dotenv()

# Configure LangSmith tracing only when an API key is actually available.
# os.environ accepts strings only, so blindly assigning os.getenv(...) would
# raise TypeError at import time whenever a variable is missing.
_langsmith_api_key = os.getenv("LANGCHAIN_API_KEY")
if _langsmith_api_key:
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_API_KEY"] = _langsmith_api_key

    # The project name is optional; skip it when it is not configured.
    _langsmith_project = os.getenv("LANGCHAIN_PROJECT")
    if _langsmith_project:
        os.environ["LANGCHAIN_PROJECT"] = _langsmith_project
def get_vectorstore_from_url(url):
    """Build a Chroma vector store from the content found at *url*.

    The page is fetched with ``WebBaseLoader``, split with a default
    ``RecursiveCharacterTextSplitter`` and embedded with ``OpenAIEmbeddings``.

    Args:
        url: Address of the web page to index.

    Returns:
        The store returned by ``Chroma.from_documents``.
    """
    # Fetch the page as document objects.
    pages = WebBaseLoader(url).load()

    # Break the documents into chunks before embedding them.
    chunks = RecursiveCharacterTextSplitter().split_documents(pages)

    # Embed the chunks and hand back the resulting store.
    return Chroma.from_documents(chunks, OpenAIEmbeddings())
def get_context_retriever_chain(vector_store):
    """Create a history-aware retriever chain on top of *vector_store*.

    Args:
        vector_store: Object exposing ``as_retriever()``.

    Returns:
        The chain produced by ``create_history_aware_retriever`` using a
        ``gpt-4o`` chat model and a prompt that asks for a search query
        derived from the conversation.
    """
    query_llm = ChatOpenAI(model_name="gpt-4o")
    doc_retriever = vector_store.as_retriever()

    # Chat history first, then the latest input, then the instruction that
    # asks the model to produce a search query.
    prompt_messages = [
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation"),
    ]
    search_prompt = ChatPromptTemplate.from_messages(prompt_messages)

    return create_history_aware_retriever(query_llm, doc_retriever, search_prompt)
def get_conversational_rag_chain(retriever_chain):
    """Combine *retriever_chain* with an answering chain.

    Args:
        retriever_chain: Retriever chain passed as the first argument to
            ``create_retrieval_chain``.

    Returns:
        The chain produced by ``create_retrieval_chain``, whose answering
        part is a stuff-documents chain driven by a ``gpt-4o`` chat model.
    """
    answer_llm = ChatOpenAI(model_name="gpt-4o")

    # System instruction carrying the retrieved context, followed by the
    # chat history and the latest user input.
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Answer the user's questions based on the below context:\n\n{context}"),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
        ]
    )

    answer_chain = create_stuff_documents_chain(answer_llm, answer_prompt)
    return create_retrieval_chain(retriever_chain, answer_chain)
def get_response(user_input):
    """Run the retrieval chain for *user_input* and return its answer.

    Reads ``vector_store`` and ``chat_history`` from ``st.session_state``.

    Args:
        user_input: Text sent to the chain as ``input``.

    Returns:
        The ``'answer'`` entry of the chain's result.
    """
    # Assemble the full pipeline from the vector store kept in the session.
    rag_chain = get_conversational_rag_chain(
        get_context_retriever_chain(st.session_state.vector_store)
    )

    payload = {
        "chat_history": st.session_state.chat_history,
        "input": user_input,
    }
    return rag_chain.invoke(payload)['answer']
# Page-level configuration and title.
st.set_page_config(page_title="Generate MCQs from website", page_icon="🤖")
st.title("Generate MCQs from website")

# Sidebar form collecting everything needed to generate the questions.
# The form is opened inside the sidebar, so its widgets render there.
with st.sidebar, st.form("user_inputs"):
    st.header("Settings")

    # Page to build the questions from.
    website_url = st.text_input("Website URL")

    # Number of questions to generate (widget is bounded to 3..50).
    mcq_count = st.number_input("No. of MCQs", min_value=3, max_value=50)

    # Subject of the quiz.
    subject = st.text_input("Insert Subject", max_chars=20)

    # Difficulty ("tone") of the quiz.
    tone = st.text_input(
        "Complexity Level Of Questions",
        max_chars=20,
        placeholder="Simple, Moderate, Hard",
    )

    button = st.form_submit_button("Create MCQs")
# Handle a form submission: index the page, ask the model for MCQs and
# render the resulting exchange.
if button:
    # st.text_input yields "" for an empty field, so validate the stripped
    # text; a comparison against None would accept a blank URL.
    page_url = (website_url or "").strip()
    quiz_subject = (subject or "").strip()
    quiz_tone = (tone or "").strip()

    if not (page_url and quiz_subject and quiz_tone):
        # Tell the user why nothing happened instead of silently ignoring
        # an incomplete form.
        st.warning("Please fill in the website URL, subject and complexity level.")
    else:
        with st.spinner("loading..."):
            try:
                # Start every submission from a clean slate so that a vector
                # store or history built for a previously submitted URL is
                # never reused for the new one.
                st.session_state.chat_history = [
                    AIMessage(content="Hello, I am a bot. How can I help you?"),
                ]
                st.session_state.vector_store = get_vectorstore_from_url(page_url)

                # Build the request sent to the retrieval chain.
                user_query = f"Create {mcq_count} mcq for {quiz_subject} and tone is {quiz_tone} and give me correct answer"
                response = get_response(user_query)

                # Keep only the latest request/answer pair in the session.
                st.session_state.chat_history = [
                    HumanMessage(content=user_query),
                    AIMessage(content=response),
                ]

                # Render the conversation.
                for message in st.session_state.chat_history:
                    if isinstance(message, AIMessage):
                        with st.chat_message("AI"):
                            st.write(message.content)
                    elif isinstance(message, HumanMessage):
                        with st.chat_message("Human"):
                            st.write(message.content)
            except Exception:
                # Top-level UI boundary: record the real cause for the
                # operator, then show the user a friendly message.
                logging.getLogger(__name__).exception(
                    "MCQ generation failed for URL %r", page_url
                )
                st.error("Please provide valid URL and refresh the page again")
# Fixed-position page footer: an inline stylesheet followed by the markup.
# Adjacent literals are concatenated by the parser into one string.
footer = (
    "<style>.footer {position: fixed;left: 0;bottom: 0;width: 100%;"
    "background-color: #000;color: white;text-align: center;}</style>"
    "<div class='footer'><p>Copyright 2023, feel free to contact [email protected]</p></div>"
)
# Raw HTML is required here, hence unsafe_allow_html.
st.markdown(footer, unsafe_allow_html=True)