Skip to content

Commit

Permalink
Merge pull request #165 from vintasoftware/langgraph-refactoring
Browse files Browse the repository at this point in the history
Migrates to langgraph
  • Loading branch information
fjsj authored Sep 11, 2024
2 parents b9f6ec0 + 93d5f74 commit ae0dbe8
Show file tree
Hide file tree
Showing 16 changed files with 1,138 additions and 79 deletions.
1 change: 1 addition & 0 deletions django_ai_assistant/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

DEFAULTS = {
"INIT_API_FN": "django_ai_assistant.api.views.init_api",
"USE_LANGGRAPH": False,
"CAN_CREATE_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_UPDATE_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
Expand Down
132 changes: 128 additions & 4 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import inspect
import re
from typing import Any, ClassVar, Sequence, cast
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
Expand All @@ -12,11 +12,20 @@
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
)
from langchain.tools import StructuredTool
from langchain_core.chat_history import (
BaseChatMessageHistory,
InMemoryChatMessageHistory,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
Expand All @@ -37,12 +46,15 @@
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from django_ai_assistant.conf import app_settings
from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
)
from django_ai_assistant.langchain.tools import Tool
from django_ai_assistant.langchain.tools import tool as tool_decorator


Expand Down Expand Up @@ -437,6 +449,115 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
prompt | llm | StrOutputParser() | retriever,
)

    @with_cast_id
    def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
        """Create the LangGraph graph for the assistant.

        This graph is an agent that supports chat history, tool calling, and RAG
        (if `has_rag=True`).

        `as_graph` uses many other methods to create the graph for the assistant.
        Prefer to override the other methods to customize the graph for the assistant.
        Only override this method if you need to customize the graph at a lower level.

        Args:
            thread_id (Any | None): The thread ID for the chat message history.
                If `None`, an in-memory chat message history is used.

        Returns:
            the compiled graph
        """
        llm = self.get_llm()
        tools = self.get_tools()
        # Only bind tools to the model when the assistant declares any.
        llm_with_tools = llm.bind_tools(tools) if tools else llm
        message_history = self.get_message_history(thread_id)

        def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
            # Reducer for AgentState.messages: merges with langgraph's
            # add_messages and, as a side effect, persists messages to the
            # chat message history when a thread_id is set.
            result = add_messages(left, right)

            if thread_id:
                # Store only user-facing turns: human/chat messages and final
                # AI answers. AI messages that carry tool_calls are
                # intermediate agent steps and are not persisted.
                messages_to_store = [
                    m
                    for m in result
                    if isinstance(m, HumanMessage | ChatMessage)
                    or (isinstance(m, AIMessage) and not m.tool_calls)
                ]
                message_history.add_messages(messages_to_store)

            return result

        class AgentState(TypedDict):
            # Conversation so far; merged/persisted via custom_add_messages.
            messages: Annotated[list[AnyMessage], custom_add_messages]
            input: str  # noqa: A003
            context: str
            output: str

        def setup(state: AgentState):
            # Seed the conversation with the system instructions.
            return {"messages": [SystemMessage(content=self.get_instructions())]}

        def retriever(state: AgentState):
            # No-op (returns None, leaving state unchanged) when RAG is disabled.
            if not self.has_rag:
                return

            retriever = self.get_history_aware_retriever()
            docs = retriever.invoke({"input": state["input"], "history": state["messages"]})

            document_separator = self.get_document_separator()
            document_prompt = self.get_document_prompt()

            # Join the retrieved documents into a single context blob.
            formatted_docs = document_separator.join(
                format_document(doc, document_prompt) for doc in docs
            )

            # Inject the retrieved context as a system message, delimited so
            # the model can tell context apart from instructions.
            return {
                "messages": SystemMessage(
                    content=f"---START OF CONTEXT---\n{formatted_docs}---END OF CONTEXT---\n"
                )
            }

        def history(state: AgentState):
            # Load prior messages (only when persisting to a thread) and append
            # the current user input as a human message.
            history = message_history.messages if thread_id else []
            return {"messages": [*history, HumanMessage(content=state["input"])]}

        def agent(state: AgentState):
            # Single LLM step over the accumulated messages; the response may
            # contain tool calls, routed by tool_selector below.
            response = llm_with_tools.invoke(state["messages"])

            return {"messages": [response]}

        def tool_selector(state: AgentState):
            # Conditional-edge router: run tools if the last AI message
            # requested any, otherwise proceed to the final response.
            last_message = state["messages"][-1]

            if isinstance(last_message, AIMessage) and last_message.tool_calls:
                return "call_tool"

            return "continue"

        def record_response(state: AgentState):
            # Expose the last message's content as the graph's "output" key.
            return {"output": state["messages"][-1].content}

        workflow = StateGraph(AgentState)

        workflow.add_node("setup", setup)
        workflow.add_node("retriever", retriever)
        workflow.add_node("history", history)
        workflow.add_node("agent", agent)
        workflow.add_node("tools", ToolNode(tools))
        workflow.add_node("respond", record_response)

        # Linear pipeline: instructions -> RAG context -> history + input,
        # then an agent/tools loop that ends at "respond".
        workflow.set_entry_point("setup")
        workflow.add_edge("setup", "retriever")
        workflow.add_edge("retriever", "history")
        workflow.add_edge("history", "agent")
        workflow.add_conditional_edges(
            "agent",
            tool_selector,
            {
                "call_tool": "tools",
                "continue": "respond",
            },
        )
        # Tool results loop back to the agent so it can use them (or call more tools).
        workflow.add_edge("tools", "agent")
        workflow.add_edge("respond", END)

        return workflow.compile()

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
Expand Down Expand Up @@ -539,7 +660,10 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
dict: The output of the assistant chain,
structured like `{"output": "assistant response", "history": ...}`.
"""
chain = self.as_chain(thread_id)
if app_settings.USE_LANGGRAPH:
chain = self.as_graph(thread_id)
else:
chain = self.as_chain(thread_id)
return chain.invoke(*args, **kwargs)

@with_cast_id
Expand Down Expand Up @@ -577,7 +701,7 @@ def as_tool(self, description: str) -> BaseTool:
Returns:
BaseTool: A tool that runs the assistant. The tool name is this assistant's id.
"""
return Tool.from_function(
return StructuredTool.from_function(
func=self._run_as_tool,
name=self.id,
description=description,
Expand Down
26 changes: 21 additions & 5 deletions django_ai_assistant/langchain/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from django.db import transaction

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.models import Message
Expand Down Expand Up @@ -73,13 +77,19 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
messages: A list of BaseMessage objects to store.
"""
with transaction.atomic():
existing_message_ids = [
str(i) for i in self._get_messages_qs().values_list("id", flat=True)
]

messages_to_create = [m for m in messages if m.id not in existing_message_ids]

created_messages = Message.objects.bulk_create(
[Message(thread_id=self._thread_id, message=dict()) for message in messages]
[Message(thread_id=self._thread_id, message=dict()) for _ in messages_to_create]
)

# Update langchain message IDs with Django message IDs
for idx, created_message in enumerate(created_messages):
message_with_id = messages[idx]
message_with_id = messages_to_create[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

Expand All @@ -91,15 +101,21 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
Args:
messages: A list of BaseMessage objects to store.
"""
existing_message_ids = [
str(i) async for i in self._get_messages_qs().values_list("id", flat=True)
]

messages_to_create = [m for m in messages if m.id not in existing_message_ids]

# NOTE: This method does not use transactions because they do not yet work in async mode.
# Source: https://docs.djangoproject.com/en/5.0/topics/async/#queries-the-orm
created_messages = await Message.objects.abulk_create(
[Message(thread_id=self._thread_id, message=dict()) for message in messages]
[Message(thread_id=self._thread_id, message=dict()) for _ in messages_to_create]
)

# Update langchain message IDs with Django message IDs
for idx, created_message in enumerate(created_messages):
message_with_id = messages[idx]
message_with_id = messages_to_create[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

Expand Down
1 change: 1 addition & 0 deletions example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@

# django-ai-assistant

AI_ASSISTANT_USE_LANGGRAPH = True
AI_ASSISTANT_INIT_API_FN = "django_ai_assistant.api.views.init_api"
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
Expand Down
34 changes: 18 additions & 16 deletions example/tour_guide/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,17 @@ class TourGuideAIAssistant(AIAssistant):
name = "Tour Guide Assistant"
instructions = (
"You are a tour guide assistant that offers information about nearby attractions. "
"The application will capture the user coordinates, and should provide a list of nearby attractions. "
"Use the available tools to suggest nearby attractions to the user. "
"You don't need to include all the found items, only include attractions that are relevant for a tourist. "
"Select the top 10 best attractions for a tourist, if there are less then 10 relevant items only return these. "
"Order items by the most relevant to the least relevant. "
"If there are no relevant attractions nearby, just keep the list empty. "
"Your response will be integrated with a frontend web application therefore it's critical that "
"it only contains a valid JSON. DON'T include '```json' in your response. "
"You will receive the user coordinates and should use available tools to find nearby attractions. "
"Only call the find_nearby_attractions tool once. "
"Your response should only contain valid JSON data. DON'T include '```json' in your response. "
"The JSON should be formatted according to the following structure: \n"
f"\n\n{_tour_guide_example_json()}\n\n\n"
"In the 'attraction_name' field provide the name of the attraction in english. "
"In the 'attraction_description' field generate an overview about the attraction with the most important information, "
"curiosities and interesting facts. "
"Only include a value for the 'attraction_url' field if you find a real value in the provided data otherwise keep it empty. "
)
model = "gpt-4o"
model = "gpt-4o-2024-08-06"

def get_instructions(self):
# Warning: this will use the server's timezone
Expand All @@ -60,11 +55,18 @@ def get_instructions(self):
return f"Today is: {current_date_str}. {self.instructions}"

@method_tool
def get_nearby_attractions_from_api(self, latitude: float, longitude: float) -> dict:
"""Find nearby attractions based on user's current location."""
return fetch_points_of_interest(
latitude=latitude,
longitude=longitude,
tags=["tourism", "leisure", "place", "building"],
radius=500,
def find_nearby_attractions(self, latitude: float, longitude: float) -> str:
"""
Find nearby attractions based on user's current location.
Returns a JSON with the list of all types of points of interest,
which may or may not include attractions.
Calls to this tool are idempotent.
"""
return json.dumps(
fetch_points_of_interest(
latitude=latitude,
longitude=longitude,
tags=["tourism", "leisure", "place", "building"],
radius=500,
)
)
Loading

0 comments on commit ae0dbe8

Please sign in to comment.