Skip to content

Commit

Permalink
Add custom retriever (#2013)
Browse files Browse the repository at this point in the history
## Description

- Custom retriever example
- Add search knowledge base tool if retriever is provided

---------

Co-authored-by: Dirk Brand <[email protected]>
  • Loading branch information
ashpreetbedi and dirkbrnd authored Feb 5, 2025
1 parent 32d0b31 commit 9068a54
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 5 deletions.
Empty file.
80 changes: 80 additions & 0 deletions cookbook/agent_concepts/knowledge/custom/retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Optional

from agno.agent import Agent
from agno.embedder.openai import OpenAIEmbedder
from agno.knowledge.pdf_url import PDFUrlKnowledgeBase
from agno.vectordb.qdrant import Qdrant
from qdrant_client import QdrantClient

# ---------------------------------------------------------
# This section loads the knowledge base. Skip if your knowledge base was populated elsewhere.
# Define the embedder
embedder = OpenAIEmbedder(id="text-embedding-3-small")
# Initialize vector database connection
vector_db = Qdrant(collection="thai-recipes", path="tmp/qdrant", embedder=embedder)
# Load the knowledge base
knowledge_base = PDFUrlKnowledgeBase(
urls=["https://agno-public.s3.amazonaws.com/recipes/ThaiRecipes.pdf"],
vector_db=vector_db,
)

# Load the knowledge base
# knowledge_base.load(recreate=True) # Comment out after first run
# Knowledge base is now loaded
# ---------------------------------------------------------


# Define the custom retriever
# This is the function that the agent will use to retrieve documents
def retriever(
query: str, agent: Optional[Agent] = None, num_documents: int = 5, **kwargs
) -> Optional[list[dict]]:
"""
Custom retriever function to search the vector database for relevant documents.
Args:
query (str): The search query string
agent (Agent): The agent instance making the query
num_documents (int): Number of documents to retrieve (default: 5)
**kwargs: Additional keyword arguments
Returns:
Optional[list[dict]]: List of retrieved documents or None if search fails
"""
try:
qdrant_client = QdrantClient(path="tmp/qdrant")
query_embedding = embedder.get_embedding(query)
results = qdrant_client.query_points(
collection_name="thai-recipes",
query=query_embedding,
limit=num_documents,
)
results_dict = results.model_dump()
if "points" in results_dict:
return results_dict["points"]
else:
return None
except Exception as e:
print(f"Error during vector database search: {str(e)}")
return None


def main():
"""Main function to demonstrate agent usage."""
# Initialize agent with custom retriever
# Remember to set search_knowledge=True to use agentic_rag or add_reference=True for traditional RAG
# search_knowledge=True is default when you add a knowledge base but is needed here
agent = Agent(
retriever=retriever,
search_knowledge=True,
instructions="Search the knowledge base for information",
show_tool_calls=True,
)

# Example query
query = "List down the ingredients to make Massaman Gai"
agent.print_response(query, markdown=True)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
search_type=SearchType.hybrid,
vector_index=VectorIndex.HNSW,
distance=Distance.COSINE,
local=True, # Set to False if using Weaviate Cloud and True if using local instance
local=True, # Set to False if using Weaviate Cloud and True if using local instance
)
# Create knowledge base
knowledge_base = PDFUrlKnowledgeBase(
Expand Down
18 changes: 14 additions & 4 deletions libs/agno/agno/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,7 @@ def get_tools(self) -> Optional[List[Union[Toolkit, Callable, Dict, Function]]]:
tools.append(self.update_memory)

# Add tools for accessing knowledge
if self.knowledge is not None:
if self.knowledge is not None or self.retriever is not None:
if self.search_knowledge:
tools.append(self.search_knowledge_base)
if self.update_knowledge:
Expand Down Expand Up @@ -2377,9 +2377,19 @@ def get_relevant_docs_from_knowledge(
"""Return a list of references from the knowledge base"""
from agno.document import Document

if self.retriever is not None:
retriever_kwargs = {"agent": self, "query": query, "num_documents": num_documents, **kwargs}
return self.retriever(**retriever_kwargs)
if self.retriever is not None and callable(self.retriever):
from inspect import signature

try:
sig = signature(self.retriever)
retriever_kwargs: Dict[str, Any] = {}
if "agent" in sig.parameters:
retriever_kwargs = {"agent": self}
retriever_kwargs.update({"query": query, "num_documents": num_documents, **kwargs})
return self.retriever(**retriever_kwargs)
except Exception as e:
logger.warning(f"Retriever failed: {e}")
return None

if self.knowledge is None:
return None
Expand Down

0 comments on commit 9068a54

Please sign in to comment.