reorganizes graph so the retriever node is self contained
filipeximenes committed Sep 10, 2024
1 parent 6de74c2 commit 59c8950
Showing 1 changed file with 23 additions and 22 deletions.
django_ai_assistant/helpers/assistants.py: 23 additions & 22 deletions
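For orientation, below is a minimal, self-contained sketch of the graph shape this commit produces: a new history node sits between retriever and agent, and the retriever appends its retrieved context directly to the message list instead of filling a separate context key. The AgentState fields, the placeholder node bodies, and the omission of the tools/respond branch are illustrative assumptions, not the project's actual implementation.

# Sketch only: graph wiring after this commit (setup -> retriever -> history -> agent).
# Node bodies are placeholders; the real project builds these inside the assistant class.
from typing import Annotated, TypedDict

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages


class AgentState(TypedDict):
    input: str
    messages: Annotated[list[AnyMessage], add_messages]


def setup(state: AgentState):
    # Seeds the system instructions (before this commit, setup injected the message history).
    return {"messages": [SystemMessage(content="You are a helpful assistant.")]}


def retriever(state: AgentState):
    # Self-contained RAG step: appends the formatted context as its own system message
    # rather than writing to a separate "context" key for the agent node to consume.
    formatted_docs = "Example retrieved document."
    return {
        "messages": [
            SystemMessage(
                content=f"---START OF CONTEXT---\n{formatted_docs}---END OF CONTEXT---\n"
            )
        ]
    }


def history(state: AgentState):
    # New node: appends the incoming user input (the real node also prepends stored history).
    return {"messages": [HumanMessage(content=state["input"])]}


def agent(state: AgentState):
    # The agent now just invokes the model on the accumulated message list.
    return {"messages": [AIMessage(content="(model response placeholder)")]}


workflow = StateGraph(AgentState)
workflow.add_node("setup", setup)
workflow.add_node("retriever", retriever)
workflow.add_node("history", history)
workflow.add_node("agent", agent)

workflow.set_entry_point("setup")
workflow.add_edge("setup", "retriever")
workflow.add_edge("retriever", "history")
workflow.add_edge("history", "agent")
# The real graph routes agent -> tools/respond via conditional edges; simplified here.
workflow.add_edge("agent", END)

graph = workflow.compile()
print(graph.invoke({"input": "Hello"}))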
@@ -12,12 +12,20 @@
     DEFAULT_DOCUMENT_PROMPT,
     DEFAULT_DOCUMENT_SEPARATOR,
 )
+from langchain.tools import StructuredTool
 from langchain_core.chat_history import (
     BaseChatMessageHistory,
     InMemoryChatMessageHistory,
 )
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ChatMessage, HumanMessage
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    BaseMessage,
+    ChatMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import (
     ChatPromptTemplate,
@@ -46,7 +54,6 @@
 from django_ai_assistant.exceptions import (
     AIAssistantMisconfiguredError,
 )
-from django_ai_assistant.langchain.tools import Tool
 from django_ai_assistant.langchain.tools import tool as tool_decorator


@@ -471,7 +478,7 @@ class AgentState(TypedDict):
         llm_with_tools = llm.bind_tools(tools) if tools else llm
 
         def setup(state: AgentState):
-            return {"messages": [*message_history.messages, HumanMessage(content=state["input"])]}
+            return {"messages": [SystemMessage(content=self.get_instructions())]}
 
         def retriever(state: AgentState):
             if not self.has_rag:
@@ -487,25 +494,17 @@ def retriever(state: AgentState):
                 format_document(doc, document_prompt) for doc in docs
             )
 
-            return {"context": formatted_docs}
-
-        def agent(state: AgentState):
-            prompt_template = ChatPromptTemplate.from_messages(
-                [
-                    ("system", self.get_instructions()),
-                    MessagesPlaceholder(variable_name="history"),
-                ]
-            )
-
-            prompt_variables: dict[str, Any] = {"history": state["messages"]}
-
-            if self.has_rag:
-                context_placeholder = self.get_context_placeholder()
-                prompt_variables[context_placeholder] = state.get("context", "")
+            return {
+                "messages": SystemMessage(
+                    content=f"---START OF CONTEXT---\n{formatted_docs}---END OF CONTEXT---\n"
+                )
+            }
 
-            prompt = prompt_template.format(**prompt_variables)
+        def history(state: AgentState):
+            return {"messages": [*message_history.messages, HumanMessage(content=state["input"])]}
 
-            response = llm_with_tools.invoke(prompt)
+        def agent(state: AgentState):
+            response = llm_with_tools.invoke(state["messages"])
 
             return {"messages": [response]}
 
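For reference, the retriever body above builds its context string with langchain's format_document helper. A minimal sketch of that formatting step, assuming the langchain_core import path and using illustrative documents and an illustrative document prompt and separator:

# Illustrative only: how format_document turns retrieved Documents into the text
# that the retriever node wraps in the START/END OF CONTEXT markers.
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate, format_document

document_prompt = PromptTemplate.from_template("{page_content}")
document_separator = "\n\n"

docs = [
    Document(page_content="Django AI Assistant integrates LLMs into Django projects."),
    Document(page_content="Retrieved context is now injected as a system message."),
]

formatted_docs = document_separator.join(
    format_document(doc, document_prompt) for doc in docs
)
print(f"---START OF CONTEXT---\n{formatted_docs}---END OF CONTEXT---\n")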
@@ -524,13 +523,15 @@ def record_response(state: AgentState):
 
         workflow.add_node("setup", setup)
         workflow.add_node("retriever", retriever)
+        workflow.add_node("history", history)
         workflow.add_node("agent", agent)
         workflow.add_node("tools", ToolNode(tools))
         workflow.add_node("respond", record_response)
 
         workflow.set_entry_point("setup")
         workflow.add_edge("setup", "retriever")
-        workflow.add_edge("retriever", "agent")
+        workflow.add_edge("retriever", "history")
+        workflow.add_edge("history", "agent")
         workflow.add_conditional_edges(
             "agent",
             tool_selector,
@@ -685,7 +686,7 @@ def as_tool(self, description: str) -> BaseTool:
         Returns:
             BaseTool: A tool that runs the assistant. The tool name is this assistant's id.
         """
-        return Tool.from_function(
+        return StructuredTool.from_function(
             func=self._run_as_tool,
             name=self.id,
             description=description,
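The switch from Tool to StructuredTool above lets the generated tool accept multiple arguments inferred from the wrapped function's signature, whereas Tool handles a single string input. A minimal sketch, with a hypothetical run_assistant function standing in for self._run_as_tool:

# Illustrative sketch: StructuredTool.from_function infers a multi-argument schema
# from the wrapped function's signature.
from langchain.tools import StructuredTool


def run_assistant(message: str, thread_id: str = "default") -> str:
    """Hypothetical stand-in for the assistant's _run_as_tool method."""
    return f"(assistant reply to {message!r} on thread {thread_id})"


assistant_tool = StructuredTool.from_function(
    func=run_assistant,
    name="my_assistant",
    description="Runs the assistant on the given message.",
)

print(assistant_tool.invoke({"message": "Hello", "thread_id": "t1"}))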
