Skip to content

Commit

Permalink
Adds RAG capability to the as_graph flow
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeximenes committed Sep 6, 2024
1 parent 2e5377d commit e356528
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
InMemoryChatMessageHistory,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
Expand All @@ -37,7 +38,8 @@
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
Expand Down Expand Up @@ -440,19 +442,30 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:

@with_cast_id
def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
from langchain_core.messages import AIMessage
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode

class AgentState(MessagesState):
response: str
input: str # noqa: A003
context: str
output: str

llm = self.get_llm()
tools = self.get_tools()
if tools:
llm_with_tools = llm.bind_tools(tools)
else:
llm_with_tools = llm
llm_with_tools = llm.bind_tools(tools) if tools else llm

def retriever(state: AgentState):
if not self.has_rag:
return

retriever = self.get_history_aware_retriever()
docs = retriever.invoke({"input": state["input"], "history": state["messages"][:-1]})

document_separator = self.get_document_separator()
document_prompt = self.get_document_prompt()

formatted_docs = document_separator.join(
format_document(doc, document_prompt) for doc in docs
)

return {"messages": [state["input"]], "context": formatted_docs}

def agent(state: AgentState):
prompt_template = ChatPromptTemplate.from_messages(
Expand All @@ -461,11 +474,17 @@ def agent(state: AgentState):
MessagesPlaceholder(variable_name="history"),
]
)

message_history = self.get_message_history(thread_id)
prompt = prompt_template.format(
history=message_history.messages + state["messages"],
)
prompt_variables: dict[str, Any] = {
"history": message_history.messages + state["messages"]
}

if self.has_rag:
context_placeholder = self.get_context_placeholder()
prompt_variables[context_placeholder] = state.get("context", "")

prompt = prompt_template.format(**prompt_variables)
response = llm_with_tools.invoke(prompt)

return {"messages": [response]}
Expand All @@ -474,22 +493,23 @@ def tool_selector(state: AgentState):
messages = state["messages"]
last_message = messages[-1]

if isinstance(last_message, AIMessage) and len(last_message.tool_calls) > 0:
if isinstance(last_message, AIMessage) and last_message.tool_calls:
return "call_tool"

return "continue"

def record_response(state: AgentState):
return {"response": state["messages"][-1].content}
return {"output": state["messages"][-1].content}

workflow = StateGraph(AgentState)

workflow.add_node("retriever", retriever)
workflow.add_node("agent", agent)
workflow.add_node("tools", ToolNode(tools))
workflow.add_node("respond", record_response)

workflow.set_entry_point("agent")
workflow.add_edge("tools", "agent")
workflow.set_entry_point("retriever")
workflow.add_edge("retriever", "agent")
workflow.add_conditional_edges(
"agent",
tool_selector,
Expand All @@ -498,6 +518,7 @@ def record_response(state: AgentState):
"continue": "respond",
},
)
workflow.add_edge("tools", "agent")
workflow.add_edge("respond", END)

return workflow.compile()
Expand Down

0 comments on commit e356528

Please sign in to comment.