Skip to content

Commit

Permalink
removes as_chain assistant function
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeximenes committed Sep 11, 2024
1 parent 8acb2e7 commit f675e4d
Showing 1 changed file with 0 additions and 94 deletions.
94 changes: 0 additions & 94 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
import re
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages,
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
Expand Down Expand Up @@ -38,12 +33,9 @@
RetrieverOutput,
)
from langchain_core.runnables import (
ConfigurableFieldSpec,
Runnable,
RunnableBranch,
RunnablePassthrough,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
Expand Down Expand Up @@ -557,91 +549,6 @@ def record_response(state: AgentState):

return workflow.compile()

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
This chain is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
`as_chain` uses many other methods to create the chain.\n
Prefer to override the other methods to customize the chain for the assistant.
Only override this method if you need to customize the chain at a lower level.
The chain input is a dictionary with the key `"input"` containing the user message.\n
The chain output is a dictionary with the key `"output"` containing the assistant response,
along with the key `"history"` containing the previous chat history.
Args:
thread_id (Any | None): The thread ID for the chat message history.
If `None`, an in-memory chat message history is used.
Returns:
Runnable[dict, dict]: The Langchain chain for the assistant.
"""
# Based on:
# - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/
# - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/
# TODO: use langgraph instead?
llm = self.get_llm()
tools = self.get_tools()
prompt = self.get_prompt_template()
tools = cast(Sequence[BaseTool], tools)
if tools:
llm_with_tools = llm.bind_tools(tools)
else:
llm_with_tools = llm
chain = (
# based on create_tool_calling_agent:
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_tool_messages(x["intermediate_steps"])
).with_config(run_name="format_to_tool_messages")
)

if self.has_rag:
# based on create_retrieval_chain:
retriever = self.get_history_aware_retriever()
chain = chain | RunnablePassthrough.assign(
docs=retriever.with_config(run_name="retrieve_documents"),
)

# based on create_stuff_documents_chain:
document_separator = self.get_document_separator()
document_prompt = self.get_document_prompt()
context_placeholder = self.get_context_placeholder()
chain = chain | RunnablePassthrough.assign(
**{
context_placeholder: lambda x: document_separator.join(
format_document(doc, document_prompt) for doc in x["docs"]
)
}
).with_config(run_name="format_input_docs")

chain = chain | prompt | llm_with_tools | ToolsAgentOutputParser()

agent_executor = AgentExecutor(
agent=chain, # pyright: ignore[reportArgumentType]
tools=tools,
)
agent_with_chat_history = RunnableWithMessageHistory(
agent_executor, # pyright: ignore[reportArgumentType]
get_session_history=self.get_message_history,
input_messages_key="input",
history_messages_key="history",
history_factory_config=[
ConfigurableFieldSpec(
id="thread_id", # must match get_message_history kwarg
annotation=int,
name="Thread ID",
description="Unique identifier for the chat thread / conversation / session.",
default=None,
is_shared=True,
),
],
).with_config(
{"configurable": {"thread_id": thread_id}},
run_name="agent_with_chat_history",
)

return agent_with_chat_history

@with_cast_id
def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
"""Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n
Expand All @@ -659,7 +566,6 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
dict: The output of the assistant chain,
structured like `{"output": "assistant response", "history": ...}`.
"""
# chain = self.as_chain(thread_id)
chain = self.as_graph(thread_id)
return chain.invoke(*args, **kwargs)

Expand Down

0 comments on commit f675e4d

Please sign in to comment.