Skip to content

Commit

Permalink
Merge pull request #180 from vintasoftware/feat/store-all-messages
Browse files Browse the repository at this point in the history
Store all messages (include tool calls)
  • Loading branch information
fjsj authored Oct 4, 2024
2 parents bdc1841 + c350737 commit ba3d4ec
Show file tree
Hide file tree
Showing 14 changed files with 394 additions and 1,737 deletions.
7 changes: 7 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@ omit =
*/migrations/*
*static*
*/templates/*
management/commands/generate_openapi_schema.py
apps.py
admin.py
exclude_lines =
pragma: no cover
if TYPE_CHECKING:
raise NotImplementedError
50 changes: 22 additions & 28 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
Expand All @@ -34,15 +33,15 @@
)
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.graph import END, StateGraph, add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
)
from django_ai_assistant.helpers.django_messages import save_django_messages
from django_ai_assistant.langchain.tools import tool as tool_decorator


Expand Down Expand Up @@ -417,46 +416,39 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
Returns:
the compiled graph
"""
# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere.
# DjangoChatMessageHistory was used in the context of langchain, now that we are using
# langgraph this can be further simplified by just porting the add_messages logic.
from django_ai_assistant.langchain.chat_message_histories import (
DjangoChatMessageHistory,
)

message_history = DjangoChatMessageHistory(thread_id) if thread_id else None
from django_ai_assistant.models import Thread

llm = self.get_llm()
tools = self.get_tools()
llm_with_tools = llm.bind_tools(tools) if tools else llm
if thread_id:
thread = Thread.objects.get(id=thread_id)
else:
thread = None

def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
result = add_messages(left, right) # type: ignore

if message_history:
messages_to_store = [
m
for m in result
if isinstance(m, HumanMessage | ChatMessage)
or (isinstance(m, AIMessage) and not m.tool_calls)
]
message_history.add_messages(messages_to_store)

if thread:
# Save all messages, except the initial system message:
thread_messages = [m for m in result if not isinstance(m, SystemMessage)]
save_django_messages(cast(list[BaseMessage], thread_messages), thread=thread)
return result

class AgentState(TypedDict):
messages: Annotated[list[AnyMessage], custom_add_messages]
input: str # noqa: A003
context: str
input: str | None # noqa: A003
output: Any

def setup(state: AgentState):
system_prompt = self.get_instructions()
return {"messages": [SystemMessage(content=system_prompt)]}

def history(state: AgentState):
history = message_history.messages if message_history else []
return {"messages": [*history, HumanMessage(content=state["input"])]}
messages = thread.get_messages(include_extra_messages=True) if thread else []
if state["input"]:
messages.append(HumanMessage(content=state["input"]))

return {"messages": messages}

def retriever(state: AgentState):
if not self.has_rag:
Expand All @@ -465,8 +457,9 @@ def retriever(state: AgentState):
retriever = self.get_history_aware_retriever()
# Remove the initial instructions to prevent having two SystemMessages
# This is necessary for compatibility with Anthropic
messages_without_input = state["messages"][1:-1]
docs = retriever.invoke({"input": state["input"], "history": messages_without_input})
messages_to_summarize = state["messages"][1:-1]
input_message = state["messages"][-1]
docs = retriever.invoke({"input": input_message, "history": messages_to_summarize})

document_separator = self.get_document_separator()
document_prompt = self.get_document_prompt()
Expand Down Expand Up @@ -550,7 +543,8 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
Args:
*args: Positional arguments to pass to the graph.
Make sure to include a `dict` like `{"input": "user message"}`.
To add a new message, use a dict like `{"input": "user message"}`.
If thread already has a `HumanMessage` in the end, you can invoke without args.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
**kwargs: Keyword arguments to pass to the graph.
Expand Down
49 changes: 49 additions & 0 deletions django_ai_assistant/helpers/django_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import TYPE_CHECKING

from django.db import transaction

from langchain_core.messages import (
BaseMessage,
message_to_dict,
)


if TYPE_CHECKING:
from django_ai_assistant.models import Message as DjangoMessage
from django_ai_assistant.models import Thread


@transaction.atomic
def save_django_messages(messages: list[BaseMessage], thread: "Thread") -> list["DjangoMessage"]:
"""
Save a list of messages to the Django database.
Note: Changes the message objects in place by changing each message.id to the Django ID.
Args:
messages (list[BaseMessage]): The list of messages to save.
thread (Thread): The thread to save the messages to.
"""

from django_ai_assistant.models import Message as DjangoMessage

existing_message_ids = [
str(i)
for i in DjangoMessage.objects.filter(thread=thread)
.order_by("created_at")
.values_list("id", flat=True)
]

messages_to_create = [m for m in messages if m.id not in existing_message_ids]

created_messages = DjangoMessage.objects.bulk_create(
[DjangoMessage(thread=thread, message={}) for _ in messages_to_create]
)

# Update langchain message IDs with Django message IDs
for idx, created_message in enumerate(created_messages):
message_with_id = messages_to_create[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

DjangoMessage.objects.bulk_update(created_messages, ["message"])
return created_messages
7 changes: 2 additions & 5 deletions django_ai_assistant/helpers/use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
AIUserNotAllowedError,
)
from django_ai_assistant.helpers.assistants import AIAssistant
from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory
from django_ai_assistant.models import Message, Thread
from django_ai_assistant.permissions import (
can_create_message,
Expand Down Expand Up @@ -291,7 +290,7 @@ def get_thread_messages(
if user != thread.created_by:
raise AIUserNotAllowedError("User is not allowed to view messages in this thread")

return DjangoChatMessageHistory(thread.id).get_messages()
return thread.get_messages(include_extra_messages=False)


def delete_message(
Expand All @@ -312,6 +311,4 @@ def delete_message(
if not can_delete_message(message=message, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this message")

return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages(
message_ids=[str(message.id)]
)
return message.delete()
175 changes: 0 additions & 175 deletions django_ai_assistant/langchain/chat_message_histories.py

This file was deleted.

Loading

0 comments on commit ba3d4ec

Please sign in to comment.