Skip to content

Commit

Permalink
removes as_chain code and references
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeximenes committed Sep 11, 2024
1 parent ae0dbe8 commit 76c4dc1
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 102 deletions.
1 change: 0 additions & 1 deletion django_ai_assistant/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

DEFAULTS = {
"INIT_API_FN": "django_ai_assistant.api.views.init_api",
"USE_LANGGRAPH": False,
"CAN_CREATE_THREAD_FN": "django_ai_assistant.permissions.allow_all",
"CAN_VIEW_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
"CAN_UPDATE_THREAD_FN": "django_ai_assistant.permissions.owns_thread",
Expand Down
101 changes: 2 additions & 99 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
import re
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages,
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
Expand Down Expand Up @@ -38,19 +33,15 @@
RetrieverOutput,
)
from langchain_core.runnables import (
ConfigurableFieldSpec,
Runnable,
RunnableBranch,
RunnablePassthrough,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

from django_ai_assistant.conf import app_settings
from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
Expand Down Expand Up @@ -558,96 +549,11 @@ def record_response(state: AgentState):

return workflow.compile()

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
This chain is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
`as_chain` uses many other methods to create the chain.\n
Prefer to override the other methods to customize the chain for the assistant.
Only override this method if you need to customize the chain at a lower level.
The chain input is a dictionary with the key `"input"` containing the user message.\n
The chain output is a dictionary with the key `"output"` containing the assistant response,
along with the key `"history"` containing the previous chat history.
Args:
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
Returns:
Runnable[dict, dict]: The Langchain chain for the assistant.
"""
# Based on:
# - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/
# - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/
# TODO: use langgraph instead?
llm = self.get_llm()
tools = self.get_tools()
prompt = self.get_prompt_template()
tools = cast(Sequence[BaseTool], tools)
if tools:
llm_with_tools = llm.bind_tools(tools)
else:
llm_with_tools = llm
chain = (
# based on create_tool_calling_agent:
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"])
).with_config(run_name="format_to_tool_messages")
)

if self.has_rag:
# based on create_retrieval_chain:
retriever = self.get_history_aware_retriever()
chain = chain | RunnablePassthrough.assign(
docs=retriever.with_config(run_name="retrieve_documents"),
)

# based on create_stuff_documents_chain:
document_separator = self.get_document_separator()
document_prompt = self.get_document_prompt()
context_placeholder = self.get_context_placeholder()
chain = chain | RunnablePassthrough.assign(
**{
context_placeholder: lambda x: document_separator.join(
format_document(doc, document_prompt) for doc in x["docs"]
)
}
).with_config(run_name="format_input_docs")

chain = chain | prompt | llm_with_tools | ToolsAgentOutputParser()

agent_executor = AgentExecutor(
agent=chain, # pyright: ignore[reportArgumentType]
tools=tools,
)
agent_with_chat_history = RunnableWithMessageHistory(
agent_executor, # pyright: ignore[reportArgumentType]
get_session_history=self.get_message_history,
input_messages_key="input",
history_messages_key="history",
history_factory_config=[
ConfigurableFieldSpec(
id="thread_id", # must match get_message_history kwarg
annotation=int,
name="Thread ID",
description="Unique identifier for the chat thread / conversation / session.",
default=None,
is_shared=True,
),
],
).with_config(
{"configurable": {"thread_id": thread_id}},
run_name="agent_with_chat_history",
)

return agent_with_chat_history

@with_cast_id
def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
"""Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n
This is the lower-level method to run the assistant.\n
The chain is created by the `as_chain` method.\n
The chain is created by the `as_graph` method.\n
Args:
*args: Positional arguments to pass to the chain.
Expand All @@ -660,10 +566,7 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
dict: The output of the assistant chain,
structured like `{"output": "assistant response", "history": ...}`.
"""
if app_settings.USE_LANGGRAPH:
chain = self.as_graph(thread_id)
else:
chain = self.as_chain(thread_id)
chain = self.as_graph(thread_id)
return chain.invoke(*args, **kwargs)

@with_cast_id
Expand Down
1 change: 0 additions & 1 deletion example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@

# django-ai-assistant

AI_ASSISTANT_USE_LANGGRAPH = True
AI_ASSISTANT_INIT_API_FN = "django_ai_assistant.api.views.init_api"
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
Expand Down
1 change: 0 additions & 1 deletion tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@
# django-ai-assistant

# NOTE: set a OPENAI_API_KEY on .env.tests file at root when updating the VCRs.
AI_ASSISTANT_USE_LANGGRAPH = True
AI_ASSISTANT_INIT_API_FN = "django_ai_assistant.api.views.init_api"
AI_ASSISTANT_CAN_CREATE_THREAD_FN = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_CAN_VIEW_THREAD_FN = "django_ai_assistant.permissions.owns_thread"
Expand Down

0 comments on commit 76c4dc1

Please sign in to comment.