
[New Feature] Support stream agent: stream back steps and messages #345

Merged · 8 commits · May 7, 2024
105 changes: 104 additions & 1 deletion erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -3,6 +3,7 @@
import logging
from typing import (
Any,
AsyncIterator,
Dict,
Final,
Iterable,
@@ -20,7 +21,13 @@
from erniebot_agent.agents.callback.default import get_default_callbacks
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.mixins import GradioMixin
from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse
from erniebot_agent.agents.schema import (
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
LLMResponse,
ToolResponse,
)
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import (
File,
@@ -131,6 +138,39 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
await self._callback_manager.on_run_end(agent=self, response=agent_resp)
return agent_resp

@final
async def run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent asynchronously.
xiabo0816 marked this conversation as resolved.
Show resolved Hide resolved

Args:
prompt: A natural language text describing the task that the agent
should perform.
files: A list of files that the agent can use to perform the task.
Returns:
Response AgentStep from the agent.
"""
if files:
await self._ensure_managed_files(files)
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
try:
async for step, msg in self._run_stream(prompt, files):
yield (step, msg)
except BaseException as e:
await self._callback_manager.on_run_error(agent=self, error=e)
raise e
else:
await self._callback_manager.on_run_end(
agent=self,
response=AgentResponse(
text="Agent run stopped early.",
chat_history=self.memory.get_messages(),
steps=[step],
status="STOPPED",
),
)
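Note (not part of this diff): a minimal, self-contained sketch of the control-flow pattern `run_stream` follows, i.e. an async generator that wraps an inner stream with start/error/end hooks. All names below are stand-ins, not the erniebot-agent API.

```python
import asyncio
from typing import AsyncIterator, Tuple


async def _inner_stream(prompt: str) -> AsyncIterator[Tuple[str, str]]:
    # Stand-in for Agent._run_stream: yields (step, message) pairs.
    for i in range(3):
        yield (f"step-{i}", f"partial answer {i} for {prompt!r}")


async def run_stream(prompt: str) -> AsyncIterator[Tuple[str, str]]:
    print("on_run_start")  # callback hook before streaming begins
    try:
        async for step, msg in _inner_stream(prompt):
            yield step, msg
    except BaseException:
        print("on_run_error")  # error hook fires before the exception propagates
        raise
    else:
        print("on_run_end")  # end hook fires once the inner stream is exhausted


async def main() -> None:
    async for step, msg in run_stream("demo"):
        print(step, msg)


asyncio.run(main())
```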

@final
async def run_llm(
self,
@@ -156,6 +196,34 @@ async def run_llm(
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
return llm_resp

@final
async def run_llm_stream(
self,
messages: List[Message],
**llm_opts: Any,
) -> AsyncIterator[LLMResponse]:
"""Run the LLM asynchronously.
xiabo0816 marked this conversation as resolved.
Show resolved Hide resolved

Args:
messages: The input messages.
llm_opts: Options to pass to the LLM.

Returns:
Response from the LLM.
"""
llm_resp = None
await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
try:
# The LLM will return an async iterator.
async for llm_resp in self._run_llm_stream(messages, **(llm_opts or {})):
yield llm_resp
except (Exception, KeyboardInterrupt) as e:
await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e)
raise e
else:
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
return

@final
async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
"""Run the specified tool asynchronously.
@@ -221,6 +289,20 @@ def get_file_manager(self) -> FileManager:
async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
raise NotImplementedError

@abc.abstractmethod
async def _run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""
Abstract asynchronous generator method that should be implemented by subclasses.
This method should yield a sequence of (AgentStep, List[Message]) tuples based on the given
prompt and optionally accompanying files.
"""
if False:
[Review thread]

Member:
Why is this here?

Contributor (PR author):
Sorry about this; it is a concession to mypy. Using a plain `pass` or `raise` here makes mypy report an error. The cause seems to be that, for a declared return type of AsyncIterator, mypy treats a body with no `yield` as returning a Coroutine, so the type no longer matches the subclass override (FunctionAgent._run_stream) and an error is reported. Since I could not find a return type annotation for the abstract `@abc.abstractmethod async def _run_stream` that passes mypy, I fell back to this ugly `if False:` trick.

Member (@Bobholamovic, Apr 30, 2024):
Oh, I see. Then I think we can keep this for now (though `if False` could perhaps be replaced with `if typing.TYPE_CHECKING`, and a HACK marker added to the comment).

# This conditional block is strictly for static type-checking purposes (e.g., mypy) and will not be executed.
only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, [])
yield only_for_mypy_type_check
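Note (not part of this diff): a minimal sketch of the `typing.TYPE_CHECKING` alternative suggested in the review thread above. The class and types are stand-ins, but the trick is the same: the `yield` makes the abstract method an async generator for the type checker while never executing at runtime.

```python
import abc
import typing
from typing import AsyncIterator, List, Tuple


class Base(abc.ABC):
    @abc.abstractmethod
    async def _run_stream(self) -> AsyncIterator[Tuple[int, List[str]]]:
        # HACK: mypy only treats this abstract method as an async generator
        # (rather than a coroutine) when a `yield` appears in its body, which
        # keeps the annotation compatible with subclass overrides. The guarded
        # block is never executed at runtime.
        if typing.TYPE_CHECKING:
            yield (0, [])
```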

async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
@@ -241,6 +323,27 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts)
return LLMResponse(message=llm_ret)

async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")

if "functions" not in opts:
functions = self._tool_manager.get_tool_schemas()
else:
functions = opts.pop("functions")

if hasattr(self.llm, "system"):
_logger.warning(
"The `system` message has already been set in the agent;"
"the `system` message configured in ERNIEBot will become ineffective."
)
opts["system"] = self.system.content if self.system is not None else None
opts["plugins"] = self._plugins
llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts)
async for msg in llm_ret:
yield LLMResponse(message=msg)
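Note (not part of this diff): a self-contained sketch of the adapter pattern used above, where each chunk from a streaming chat call is wrapped in a response object as it arrives. The chat client and response type here are stand-ins, not the ERNIE Bot SDK.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator


@dataclass
class StreamedResponse:
    message: str


async def fake_chat_stream() -> AsyncIterator[str]:
    # Stand-in for an LLM client called with stream=True: yields partial chunks.
    for chunk in ("Hel", "lo, ", "world"):
        yield chunk


async def run_llm_stream() -> AsyncIterator[StreamedResponse]:
    # Wrap each raw chunk in a typed response object as it arrives.
    async for msg in fake_chat_stream():
        yield StreamedResponse(message=msg)


async def main() -> None:
    async for resp in run_llm_stream():
        print(resp.message, end="")
    print()


asyncio.run(main())
```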

async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
parsed_tool_args = self._parse_tool_args(tool_args)
file_manager = self.get_file_manager()
138 changes: 125 additions & 13 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -13,7 +16,7 @@
# limitations under the License.

import logging
from typing import Final, Iterable, List, Optional, Sequence, Tuple, Union
from typing import (
AsyncIterator,
Final,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)

from erniebot_agent.agents.agent import Agent
from erniebot_agent.agents.callback.callback_manager import CallbackManager
@@ -31,7 +40,12 @@
from erniebot_agent.chat_models.erniebot import BaseERNIEBot
from erniebot_agent.file import File, FileManager
from erniebot_agent.memory import Memory
from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message
from erniebot_agent.memory.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
Message,
)
from erniebot_agent.tools.base import BaseTool
from erniebot_agent.tools.tool_manager import ToolManager

@@ -136,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._step(chat_history, selected_tool=tool)
curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
@@ -167,21 +181,119 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse:
response = self._create_stopped_response(chat_history, steps_taken)
return response

async def _step(
async def _first_tool_step(
self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
) -> Tuple[AgentStep, List[Message]]:
new_messages: List[Message] = []
input_messages = self.memory.get_messages() + chat_history
if selected_tool is not None:
tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
llm_resp = await self.run_llm(
messages=input_messages,
functions=[selected_tool.function_call_schema()], # only regist one tool
tool_choice=tool_choice,
)
else:
if selected_tool is None:
llm_resp = await self.run_llm(messages=input_messages)
return await self._schema_format(llm_resp, chat_history)

tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
llm_resp = await self.run_llm(
messages=input_messages,
functions=[selected_tool.function_call_schema()],  # only register one tool
tool_choice=tool_choice,
)
return await self._schema_format(llm_resp, chat_history)

async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
"""Run a step of the agent.
Args:
chat_history: The chat history to provide to the agent.
Returns:
A tuple of an agent step and a list of new messages.
"""
input_messages = self.memory.get_messages() + chat_history
llm_resp = await self.run_llm(messages=input_messages)
return await self._schema_format(llm_resp, chat_history)

async def _step_stream(
self, chat_history: List[Message]
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run a step of the agent in streaming mode.
Args:
chat_history: The chat history to provide to the agent.
Returns:
An async iterator that yields a tuple of an agent step and a list of new messages.
"""
input_messages = self.memory.get_messages() + chat_history
async for llm_resp in self.run_llm_stream(messages=input_messages):
yield await self._schema_format(llm_resp, chat_history)

async def _run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent with the given prompt and files in streaming mode.
Args:
prompt: The prompt for the agent to run.
files: A list of files for the agent to use. If `None`, use an empty
list.
Yields:
Tuples of an agent step and a list of new messages, one pair per step.
"""
chat_history: List[Message] = []
steps_taken: List[AgentStep] = []

run_input = await HumanMessage.create_with_files(
prompt, files or [], include_file_urls=self.file_needs_url
)

num_steps_taken = 0
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
steps_taken.append(curr_step)
else:
# If the tool choice did not work, skip this round
_logger.warning(f"Selected tool [{tool.tool_name}] did not work")

is_finished = False
new_messages = []
end_step_msgs = []
while is_finished is False:
# IMPORTANT! The following code consumes the streamed response from the LLM.
# When finish_reason is function_call, run_llm_stream returns all of the information in a single step,
# but for a normal chat completion, run_llm_stream returns the information across multiple steps.
async for curr_step, new_messages in self._step_stream(chat_history):
if isinstance(curr_step, ToolStep):
steps_taken.append(curr_step)
yield curr_step, new_messages

elif isinstance(curr_step, PluginStep):
steps_taken.append(curr_step)
# Reserved for a future interface where the run does not end after a plugin call.

# Here the plugin ends the run immediately after being called.
curr_step = DEFAULT_FINISH_STEP
yield curr_step, new_messages

elif isinstance(curr_step, EndStep):
is_finished = True
end_step_msgs.extend(new_messages)
yield curr_step, new_messages
else:
raise RuntimeError("Invalid step type")
chat_history.extend(new_messages)

self.memory.add_message(run_input)
end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
self.memory.add_message(end_step_msg)
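Note (not part of this diff): a toy illustration of the final assembly step above, where the streamed end-step message chunks are concatenated into a single message before being written to memory. The `AIMessage` here is a stand-in dataclass, not the erniebot-agent class.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AIMessage:
    content: str


def assemble_final_message(end_step_msgs: List[AIMessage]) -> AIMessage:
    # Mirror of the join above: concatenate the partial contents streamed for
    # the final (EndStep) answer into one message for the memory.
    return AIMessage(content="".join(m.content for m in end_step_msgs))


chunks = [AIMessage("The weather "), AIMessage("in Beijing "), AIMessage("is sunny.")]
print(assemble_final_message(chunks).content)  # The weather in Beijing is sunny.
```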

async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
"""Convert the LLM response to the agent response schema.
Args:
llm_resp: The LLM response to convert.
chat_history: The chat history to provide to the agent.
Returns:
A tuple of an agent step and a list of new messages.
"""
new_messages: List[Message] = []
output_message = llm_resp.message # AIMessage
new_messages.append(output_message)
if output_message.function_call is not None: