Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pranshul #11

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 54 additions & 48 deletions rag_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
Answer: Let's think step by step.
"""


class RAG_Agent:
def __init__(self, model="llama3.1"):
self.model = OllamaLLM(model=model)
Expand All @@ -59,21 +58,30 @@ def setup_vectorstore(self, file_paths):
"""
all_documents = []
for file_path in file_paths:
if os.path.exists(file_path):
if file_path.endswith(".pdf"):
loader = PyMuPDFLoader(file_path)
try:
if os.path.exists(file_path):
print(f"Loading {file_path}...")
if file_path.endswith(".pdf"):
loader = PyMuPDFLoader(file_path)
else:
loader = TextLoader(file_path)

documents = loader.load()
all_documents.extend(documents)
else:
loader = TextLoader(file_path)

documents = loader.load()
all_documents.extend(documents)

# Using HuggingFace embeddings
embeddings = HuggingFaceEmbeddings()
vector_store = FAISS.from_documents(all_documents, embeddings)
retriever = vector_store.as_retriever()
return retriever

warnings.warn(f"File {file_path} not found.")
except Exception as e:
warnings.warn(f"Error processing file {file_path}: {str(e)}")

if all_documents:
embeddings = HuggingFaceEmbeddings()
vector_store = FAISS.from_documents(all_documents, embeddings)
retriever = vector_store.as_retriever()
print("Document loading complete.")
return retriever
else:
raise ValueError("No valid documents found to load.")

def get_relevant_document(self, query, threshold=0.5):
"""
Check if the query is related to a document by using the retriever.
Expand All @@ -86,51 +94,49 @@ def get_relevant_document(self, query, threshold=0.5):
Returns:
tuple: (top_result, score) if relevant document found, otherwise (None, 0.0)
"""
results = self.retriever.invoke(query)

if results:
# Check the confidence score of the top result
top_result = results[0]
score = top_result.metadata.get("score", 0.0)

if score >= threshold:
return top_result, score
return None, 0.0


try:
results = self.retriever.invoke(query)
if results:
# Check the confidence score of the top result
top_result = results[0]
score = top_result.metadata.get("score", 0.0)

if score >= threshold:
return top_result, score
return None, 0.0
except Exception as e:
warnings.warn(f"Error retrieving relevant document: {str(e)}")
return None, 0.0

def run(self):
# Format documents for context
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

# Create the LCEL chain
# Ref. https://python.langchain.com/docs/versions/migrating_chains/retrieval_qa/
qa_chain = (
{
"context": self.retriever | format_docs,
"question": RunnablePassthrough()
}
| self.prompt
| self.model
| StrOutputParser()
)

"""
Main loop to interact with the user.
"""
print("Ask your Python or GTK-based coding questions! Type 'exit' to quit.")

while True:
question = input().strip()

if question.lower() == 'exit':
print("Exiting the assistant...")
break

doc_result, relevance_score = self.get_relevant_document(question)
# Classifying query based on it's relevance with retrieved context
if doc_result:
response = qa_chain.invoke({"query": question, "context": doc_result.page_content})
else:
response = qa_chain.invoke(question)

return response
print(response)


if __name__ == "__main__":
agent = RAG_Agent()
agent.retriever = agent.setup_vectorstore(document_paths)
agent.run()
try:
agent = RAG_Agent()
agent.retriever = agent.setup_vectorstore(document_paths)
agent.run()
except ValueError as ve:
print(f"Error: {str(ve)}")
except Exception as e:
print(f"Unexpected error: {str(e)}")