Merge pull request #167 from vintasoftware/remove-aschain
Removes as_chain
fjsj authored Sep 11, 2024
2 parents ae0dbe8 + 2fe45a7 commit 81d0ff6
Showing 9 changed files with 771 additions and 206 deletions.
1 change: 0 additions & 1 deletion django_ai_assistant/conf.py
@@ -10,7 +10,6 @@

DEFAULTS = {
"INIT_API_FN": "django_ai_assistant.api.views.init_api",
"USE_LANGGRAPH": False,
"CAN_CREATE_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_UPDATE_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
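Each key in `DEFAULTS` maps to a Django setting with the `AI_ASSISTANT_` prefix, as the settings files later in this diff show. A minimal sketch of overriding one (the permission function path is hypothetical):

```python
# settings.py — values are dotted import paths resolved by the package
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "myapp.permissions.staff_only"  # hypothetical path
```
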
194 changes: 19 additions & 175 deletions django_ai_assistant/helpers/assistants.py
@@ -3,20 +3,11 @@
import re
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages,
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
)
from langchain.tools import StructuredTool
from langchain_core.chat_history import (
BaseChatMessageHistory,
InMemoryChatMessageHistory,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
@@ -38,19 +29,15 @@
RetrieverOutput,
)
from langchain_core.runnables import (
ConfigurableFieldSpec,
Runnable,
RunnableBranch,
RunnablePassthrough,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from django_ai_assistant.conf import app_settings
from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
@@ -90,8 +77,6 @@ class AIAssistant(abc.ABC): # noqa: F821
"""Whether the assistant uses RAG (Retrieval-Augmented Generation) or not.\n
Defaults to `False`.
When True, the assistant will use a retriever to get documents to provide as context to the LLM.
For this to work, the `instructions` should contain a placeholder for the context,
which is `{context}` by default.
Additionally, the assistant class should implement the `get_retriever` method to return
the retriever to use."""
_user: Any | None
@@ -266,58 +251,6 @@ def get_model_kwargs(self) -> dict[str, Any]:
"""
return {}

def get_prompt_template(self) -> ChatPromptTemplate:
"""Get the `ChatPromptTemplate` for the Langchain chain to use.\n
The system prompt comes from the `get_instructions` method.\n
The template includes placeholders for the instructions, chat `{history}`, user `{input}`,
and `{agent_scratchpad}`, all which are necessary for the chain to work properly.\n
The chat history is filled by the chain using the message history from `get_message_history`.\n
If the assistant uses RAG, the instructions should contain a placeholder
for the context, which is `{context}` by default, defined by the `get_context_placeholder` method.
Returns:
ChatPromptTemplate: The chat prompt template for the Langchain chain.
"""
instructions = self.get_instructions()
context_placeholder = self.get_context_placeholder()
if self.has_rag and f"{context_placeholder}" not in instructions:
raise AIAssistantMisconfiguredError(
f"{self.__class__.__name__} has_rag=True"
f"but does not have a {{{context_placeholder}}} placeholder in instructions."
)

return ChatPromptTemplate.from_messages(
[
("system", instructions),
MessagesPlaceholder(variable_name="history"),
("human", "{input}"),
("placeholder", "{agent_scratchpad}"),
]
)

@with_cast_id
def get_message_history(self, thread_id: Any | None) -> BaseChatMessageHistory:
"""Get the chat message history instance for the given `thread_id`.\n
The Langchain chain uses the return of this method to get the thread messages
for the assistant, filling the `history` placeholder in the `get_prompt_template`.\n
Args:
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
Returns:
BaseChatMessageHistory: The chat message history instance for the given `thread_id`.
"""

# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere:
from django_ai_assistant.langchain.chat_message_histories import (
DjangoChatMessageHistory,
)

if thread_id is None:
return InMemoryChatMessageHistory()
return DjangoChatMessageHistory(thread_id)

def get_llm(self) -> BaseChatModel:
"""Get the Langchain LLM instance for the assistant.
By default, this uses the OpenAI implementation.\n
@@ -368,15 +301,6 @@ def get_document_prompt(self) -> PromptTemplate:
"""
return DEFAULT_DOCUMENT_PROMPT

def get_context_placeholder(self) -> str:
"""Get the RAG context placeholder to use in the prompt when `has_rag=True`.\n
Defaults to `"context"`. Override this method to use a different placeholder.
Returns:
str: the RAG context placeholder to use in the prompt.
"""
return "context"

def get_retriever(self) -> BaseRetriever:
"""Get the RAG retriever to use for fetching documents.\n
Must be implemented by subclasses when `has_rag=True`.\n
@@ -464,15 +388,23 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
Returns:
the compiled graph
"""
# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere.
# DjangoChatMessageHistory was used in the context of langchain; now that we are using
# langgraph, this can be further simplified by just porting the add_messages logic.
from django_ai_assistant.langchain.chat_message_histories import (
DjangoChatMessageHistory,
)

message_history = DjangoChatMessageHistory(thread_id) if thread_id else None

llm = self.get_llm()
tools = self.get_tools()
llm_with_tools = llm.bind_tools(tools) if tools else llm
message_history = self.get_message_history(thread_id)

def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
result = add_messages(left, right)

if thread_id:
if message_history:
messages_to_store = [
m
for m in result
@@ -513,7 +445,7 @@ def retriever(state: AgentState):
}

def history(state: AgentState):
history = message_history.messages if thread_id else []
history = message_history.messages if message_history else []
return {"messages": [*history, HumanMessage(content=state["input"])]}

def agent(state: AgentState):
@@ -558,113 +490,25 @@ def record_response(state: AgentState):

return workflow.compile()

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
This chain is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
`as_chain` uses many other methods to create the chain.\n
Prefer to override the other methods to customize the chain for the assistant.
Only override this method if you need to customize the chain at a lower level.
The chain input is a dictionary with the key `"input"` containing the user message.\n
The chain output is a dictionary with the key `"output"` containing the assistant response,
along with the key `"history"` containing the previous chat history.
Args:
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
Returns:
Runnable[dict, dict]: The Langchain chain for the assistant.
"""
# Based on:
# - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/
# - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/
# TODO: use langgraph instead?
llm = self.get_llm()
tools = self.get_tools()
prompt = self.get_prompt_template()
tools = cast(Sequence[BaseTool], tools)
if tools:
llm_with_tools = llm.bind_tools(tools)
else:
llm_with_tools = llm
chain = (
# based on create_tool_calling_agent:
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"])
).with_config(run_name="format_to_tool_messages")
)

if self.has_rag:
# based on create_retrieval_chain:
retriever = self.get_history_aware_retriever()
chain = chain | RunnablePassthrough.assign(
docs=retriever.with_config(run_name="retrieve_documents"),
)

# based on create_stuff_documents_chain:
document_separator = self.get_document_separator()
document_prompt = self.get_document_prompt()
context_placeholder = self.get_context_placeholder()
chain = chain | RunnablePassthrough.assign(
**{
context_placeholder: lambda x: document_separator.join(
format_document(doc, document_prompt) for doc in x["docs"]
)
}
).with_config(run_name="format_input_docs")

chain = chain | prompt | llm_with_tools | ToolsAgentOutputParser()

agent_executor = AgentExecutor(
agent=chain, # pyright: ignore[reportArgumentType]
tools=tools,
)
agent_with_chat_history = RunnableWithMessageHistory(
agent_executor, # pyright: ignore[reportArgumentType]
get_session_history=self.get_message_history,
input_messages_key="input",
history_messages_key="history",
history_factory_config=[
ConfigurableFieldSpec(
id="thread_id", # must match get_message_history kwarg
annotation=int,
name="Thread ID",
description="Unique identifier for the chat thread / conversation / session.",
default=None,
is_shared=True,
),
],
).with_config(
{"configurable": {"thread_id": thread_id}},
run_name="agent_with_chat_history",
)

return agent_with_chat_history

@with_cast_id
def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
"""Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n
"""Invoke the assistant Langchain graph with the given arguments and keyword arguments.\n
This is the lower-level method to run the assistant.\n
The chain is created by the `as_chain` method.\n
The graph is created by the `as_graph` method.\n
Args:
*args: Positional arguments to pass to the chain.
*args: Positional arguments to pass to the graph.
Make sure to include a `dict` like `{"input": "user message"}`.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
**kwargs: Keyword arguments to pass to the chain.
**kwargs: Keyword arguments to pass to the graph.
Returns:
dict: The output of the assistant chain,
dict: The output of the assistant graph,
structured like `{"output": "assistant response", "history": ...}`.
"""
if app_settings.USE_LANGGRAPH:
chain = self.as_graph(thread_id)
else:
chain = self.as_chain(thread_id)
return chain.invoke(*args, **kwargs)
graph = self.as_graph(thread_id)
return graph.invoke(*args, **kwargs)

@with_cast_id
def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
Expand All @@ -675,7 +519,7 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
message (str): The user message to pass to the assistant.
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
**kwargs: Additional keyword arguments to pass to the chain.
**kwargs: Additional keyword arguments to pass to the graph.
Returns:
str: The assistant response to the user message.
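With `as_chain` gone, `run` and `invoke` always execute the compiled langgraph graph built by `as_graph`. A minimal usage sketch, assuming a hypothetical assistant subclass with the class attributes the tutorial uses:

```python
from django_ai_assistant import AIAssistant

class GreetingAssistant(AIAssistant):
    id = "greeting_assistant"  # hypothetical assistant, for illustration only
    name = "Greeting Assistant"
    instructions = "You greet users politely."
    model = "gpt-4o"

assistant = GreetingAssistant()

# High-level entry point: returns just the response string.
response = assistant.run("Hello!", thread_id=None)

# Lower-level: pass an {"input": ...} dict, per the invoke docstring above;
# returns a dict structured like {"output": ..., "history": ...}.
result = assistant.invoke({"input": "Hello!"}, thread_id=None)
```
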
2 changes: 1 addition & 1 deletion django_ai_assistant/helpers/use_cases.py
@@ -120,7 +120,7 @@ def create_message(
content (Any): Message content, usually a string
request (HttpRequest | None): Current request, if any
Returns:
dict: The output of the assistant chain,
dict: The output of the assistant,
structured like `{"output": "assistant response", "history": ...}`
Raises:
AIUserNotAllowedError: If user is not allowed to create messages in the thread
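For context, a hedged sketch of calling `create_message` as documented above; apart from `content` and `request`, the keyword names are assumptions, not verified against this commit:

```python
from django_ai_assistant.helpers import use_cases

result = use_cases.create_message(
    assistant_id="greeting_assistant",  # assumed parameter name
    thread=thread,                      # an existing Thread instance
    user=request.user,
    content="Recommend a movie",
    request=request,
)
print(result["output"])  # the assistant response string
```
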
16 changes: 5 additions & 11 deletions docs/tutorial.md
@@ -1,6 +1,6 @@
---
search:
boost: 2
boost: 2
---

# Tutorial
@@ -274,7 +274,7 @@ urlpatterns = [
path("ai-assistant/", include("django_ai_assistant.urls")),
...
]
```
```

The built-in API supports retrieval of Assistants info, as well as CRUD for Threads and Messages.
It has an OpenAPI schema that you can explore at `http://localhost:8000/ai-assistant/docs` when running your project locally.
@@ -415,15 +415,13 @@ shows an example of a composed AI Assistant that's able to recommend movies and
### Retrieval Augmented Generation (RAG)

You can use RAG in your AI Assistants. RAG means using a retriever to fetch chunks of textual data from a pre-existing DB to give
context to the LLM. This context goes into the `{context}` placeholder in the `instructions` string, namely the system prompt.
This means the LLM will have access to a context your retriever logic provides when generating the response,
context to the LLM. This means the LLM will have access to a context your retriever logic provides when generating the response,
thereby improving the quality of the response by avoiding generic or off-topic answers.

For this to work, you must do the following in your AI Assistant:

1. Add a `{context}` placeholder in the `instructions` string;
2. Add `has_rag = True` as a class attribute;
3. Override the `get_retriever` method to return a [Langchain Retriever](https://python.langchain.com/v0.2/docs/how_to/#retrievers).
1. Add `has_rag = True` as a class attribute;
2. Override the `get_retriever` method to return a [Langchain Retriever](https://python.langchain.com/v0.2/docs/how_to/#retrievers).

For example:

@@ -436,10 +434,6 @@ class DocsAssistant(AIAssistant):
instructions = (
"You are an assistant for answering questions related to the provided context. "
"Use the following pieces of retrieved context to answer the user's question. "
"\n\n"
"---START OF CONTEXT---\n"
"{context}"
"---END OF CONTEXT---\n"
)
model = "gpt-4o"
has_rag = True
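The example class above is truncated in this view; the remaining piece when `has_rag = True` is the `get_retriever` override from step 2. A hedged sketch, where `docs_vector_store` stands in for whatever vector index your project builds:

```python
from langchain_core.retrievers import BaseRetriever

class DocsAssistant(AIAssistant):
    # ... id, name, instructions, model as above ...

    def get_retriever(self) -> BaseRetriever:
        # Return the top-k most relevant chunks for the user's question
        return docs_vector_store.as_retriever(search_kwargs={"k": 4})
```
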
1 change: 0 additions & 1 deletion example/example/settings.py
@@ -159,7 +159,6 @@

# django-ai-assistant

AI_ASSISTANT_USE_LANGGRAPH = True
AI_ASSISTANT_INIT_API_FN = "django_ai_assistant.api.views.init_api"
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
4 changes: 0 additions & 4 deletions example/rag/ai_assistants.py
@@ -14,10 +14,6 @@ class DjangoDocsAssistant(AIAssistant):
"Use the following pieces of retrieved context from Django's documentation to answer "
"the user's question. If you don't know the answer, say that you don't know. "
"Use three sentences maximum and keep the answer concise."
"\n\n"
"---START OF CONTEXT---\n"
"{context}"
"---END OF CONTEXT---\n"
)
model = "gpt-4o"
has_rag = True
1 change: 0 additions & 1 deletion tests/settings.py
@@ -107,7 +107,6 @@
# django-ai-assistant

# NOTE: set an OPENAI_API_KEY in the .env.tests file at the root when updating the VCRs.
AI_ASSISTANT_USE_LANGGRAPH = True
AI_ASSISTANT_INIT_API_FN = "django_ai_assistant.api.views.init_api"
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"