Skip to content

Commit

Permalink
Adds an as_graph method to the AI assistant class to port the assistant to LangGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeximenes committed Sep 5, 2024
1 parent b9f6ec0 commit 2e5377d
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 4 deletions.
65 changes: 65 additions & 0 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
Expand Down Expand Up @@ -437,6 +438,70 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
prompt | llm | StrOutputParser() | retriever,
)

@with_cast_id
def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
    """Build and compile a LangGraph agent graph for this assistant.

    The graph cycles between an LLM node and a tool-execution node for as
    long as the model keeps emitting tool calls; once it stops, the final
    message content is recorded under the ``response`` state key.

    Args:
        thread_id: Optional id of the thread whose persisted message history
            is prepended to the prompt. ``None`` means no stored history.

    Returns:
        A compiled, runnable LangGraph workflow.
    """
    from langchain_core.messages import AIMessage
    from langgraph.graph import END, StateGraph
    from langgraph.prebuilt import ToolNode

    class AgentState(MessagesState):
        # Final textual answer of the run, written by the "respond" node.
        response: str

    llm = self.get_llm()
    tools = self.get_tools()
    # Bind tools only when some exist; binding an empty tool list is skipped.
    model = llm.bind_tools(tools) if tools else llm

    def call_model(state: AgentState):
        # Prompt = system instructions + persisted thread history + the
        # messages accumulated during this run (including tool results).
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", self.instructions),
                MessagesPlaceholder(variable_name="history"),
            ]
        )
        message_history = self.get_message_history(thread_id)
        prompt = prompt_template.format(
            history=message_history.messages + state["messages"],
        )
        return {"messages": [model.invoke(prompt)]}

    def route_after_model(state: AgentState):
        # Route to the tool node while the model keeps requesting tools.
        last_message = state["messages"][-1]
        wants_tools = (
            isinstance(last_message, AIMessage) and len(last_message.tool_calls) > 0
        )
        return "call_tool" if wants_tools else "continue"

    def capture_response(state: AgentState):
        # Copy the model's final message content into the dedicated state key.
        return {"response": state["messages"][-1].content}

    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", ToolNode(tools))
    workflow.add_node("respond", capture_response)

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        route_after_model,
        {
            "call_tool": "tools",
            "continue": "respond",
        },
    )
    workflow.add_edge("tools", "agent")
    workflow.add_edge("respond", END)

    return workflow.compile()

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
Expand Down
42 changes: 38 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pydantic = "^2.7.1"
django-ninja = "^1.1.0"
langchain = "^0.2.1"
langchain-openai = "^0.1.8"
langgraph = "^0.2.16"

[tool.poetry.group.dev.dependencies]
coverage = "^7.2.7"
Expand Down

0 comments on commit 2e5377d

Please sign in to comment.