Skip to content

Commit

Permalink
only stores messages that are from humam or from the AI as long as it…
Browse files Browse the repository at this point in the history
…'s not a tool call
  • Loading branch information
filipeximenes committed Sep 10, 2024
1 parent ad4320f commit 79d6d13
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
20 changes: 13 additions & 7 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
InMemoryChatMessageHistory,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ChatMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
Expand Down Expand Up @@ -446,13 +446,19 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
message_history = self.get_message_history(thread_id)

def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
if thread_id is None:
return add_messages(left, right)

message_history.add_messages(right)
messages = message_history.messages
result = add_messages(left, right)

if thread_id:
# We only want to store human and ai messages that are not tool calls
messages_to_store = [
m
for m in result
if isinstance(m, HumanMessage | ChatMessage)
or (isinstance(m, AIMessage) and not m.tool_calls)
]
message_history.add_messages(messages_to_store)

return messages
return result

class AgentState(TypedDict):
messages: Annotated[list[AnyMessage], custom_add_messages]
Expand Down
8 changes: 6 additions & 2 deletions django_ai_assistant/langchain/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,19 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
Args:
messages: A list of BaseMessage objects to store.
"""
existing_messages = self.messages

messages_to_create = [m for m in messages if m not in existing_messages]

# NOTE: This method does not use transactions because it do not yet work in async mode.
# Source: https://docs.djangoproject.com/en/5.0/topics/async/#queries-the-orm
created_messages = await Message.objects.abulk_create(
[Message(thread_id=self._thread_id, message=dict()) for message in messages]
[Message(thread_id=self._thread_id, message=dict()) for _ in messages_to_create]
)

# Update langchain message IDs with Django message IDs
for idx, created_message in enumerate(created_messages):
message_with_id = messages[idx]
message_with_id = messages_to_create[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

Expand Down

0 comments on commit 79d6d13

Please sign in to comment.