From d3cfef2caf8147b89d030239e6fdc2bcacce521d Mon Sep 17 00:00:00 2001 From: xiabo Date: Fri, 26 Apr 2024 17:38:05 +0800 Subject: [PATCH 1/8] =?UTF-8?q?stream=20agent=20=E6=B5=81=E5=BC=8F?= =?UTF-8?q?=E8=BF=94=E5=9B=9Estep=E5=92=8Cmessage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/erniebot_agent/agents/agent.py | 84 +++++++++++- .../erniebot_agent/agents/function_agent.py | 126 ++++++++++++++++-- .../agents/test_functional_agent_stream.py | 97 ++++++++++++++ 3 files changed, 292 insertions(+), 15 deletions(-) create mode 100644 erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index 2768eaea..527c7278 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -2,6 +2,7 @@ import json import logging from typing import ( + AsyncIterator, Any, Dict, Final, @@ -20,7 +21,7 @@ from erniebot_agent.agents.callback.default import get_default_callbacks from erniebot_agent.agents.callback.handlers.base import CallbackHandler from erniebot_agent.agents.mixins import GradioMixin -from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse +from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse, AgentStep from erniebot_agent.chat_models.erniebot import BaseERNIEBot from erniebot_agent.file import ( File, @@ -131,6 +132,30 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen await self._callback_manager.on_run_end(agent=self, response=agent_resp) return agent_resp + @final + async def run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]: + """Run the agent asynchronously. + + Args: + prompt: A natural language text describing the task that the agent + should perform. + files: A list of files that the agent can use to perform the task. + Returns: + Response AgentStep from the agent. + """ + if files: + await self._ensure_managed_files(files) + await self._callback_manager.on_run_start(agent=self, prompt=prompt) + try: + async for step in self._run_stream(prompt, files): + yield step + except BaseException as e: + await self._callback_manager.on_run_error(agent=self, error=e) + raise e + else: + await self._callback_manager.on_run_end(agent=self, response=None) + return + @final async def run_llm( self, @@ -155,6 +180,36 @@ async def run_llm( else: await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp) return llm_resp + + + @final + async def run_llm_stream( + self, + messages: List[Message], + **llm_opts: Any, + ) -> AsyncIterator[LLMResponse]: + """Run the LLM asynchronously. + + Args: + messages: The input messages. + llm_opts: Options to pass to the LLM. + + Returns: + Response from the LLM. + """ + llm_resp = None + await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages) + try: + # The LLM will return an async iterator. + async for llm_resp in self._run_llm_stream(messages, **(llm_opts or {})): + yield llm_resp + except (Exception, KeyboardInterrupt) as e: + await self._callback_manager.on_llm_error(agent=self, llm=self.llm, error=e) + raise e + else: + await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp) + return + @final async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse: @@ -221,6 +276,10 @@ def get_file_manager(self) -> FileManager: async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> AgentResponse: raise NotImplementedError + @abc.abstractmethod + async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]: + raise NotImplementedError + async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: for reserved_opt in ("stream", "system", "plugins"): if reserved_opt in opts: @@ -240,6 +299,29 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: opts["plugins"] = self._plugins llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts) return LLMResponse(message=llm_ret) + + + async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]: + for reserved_opt in ("stream", "system", "plugins"): + if reserved_opt in opts: + raise TypeError(f"`{reserved_opt}` should not be set.") + + if "functions" not in opts: + functions = self._tool_manager.get_tool_schemas() + else: + functions = opts.pop("functions") + + if hasattr(self.llm, "system"): + _logger.warning( + "The `system` message has already been set in the agent;" + "the `system` message configured in ERNIEBot will become ineffective." + ) + opts["system"] = self.system.content if self.system is not None else None + opts["plugins"] = self._plugins + llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts) + async for msg in llm_ret: + yield LLMResponse(message=msg) + async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse: parsed_tool_args = self._parse_tool_args(tool_args) diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index 5e2d7746..378e33c8 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Final, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Final, Iterable, List, Optional, Sequence, Tuple, Union, AsyncIterator from erniebot_agent.agents.agent import Agent from erniebot_agent.agents.callback.callback_manager import CallbackManager @@ -31,7 +31,7 @@ from erniebot_agent.chat_models.erniebot import BaseERNIEBot from erniebot_agent.file import File, FileManager from erniebot_agent.memory import Memory -from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message +from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message, AIMessage from erniebot_agent.tools.base import BaseTool from erniebot_agent.tools.tool_manager import ToolManager @@ -136,7 +136,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age chat_history.append(run_input) for tool in self._first_tools: - curr_step, new_messages = await self._step(chat_history, selected_tool=tool) + curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool) if not isinstance(curr_step, EndStep): chat_history.extend(new_messages) num_steps_taken += 1 @@ -166,22 +166,120 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age num_steps_taken += 1 response = self._create_stopped_response(chat_history, steps_taken) return response + - async def _step( - self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None + async def _first_tool_step( + self, chat_history: List[Message], selected_tool: BaseTool = None ) -> Tuple[AgentStep, List[Message]]: - new_messages: List[Message] = [] input_messages = self.memory.get_messages() + chat_history - if selected_tool is not None: - tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}} - llm_resp = await self.run_llm( - messages=input_messages, - functions=[selected_tool.function_call_schema()], # only regist one tool - tool_choice=tool_choice, - ) - else: + if selected_tool is None: llm_resp = await self.run_llm(messages=input_messages) + return await self._schema_format(llm_resp, chat_history) + + tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}} + llm_resp = await self.run_llm( + messages=input_messages, + functions=[selected_tool.function_call_schema()], # only regist one tool + tool_choice=tool_choice, + ) + return await self._schema_format(llm_resp, chat_history) + + + async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]: + """Run a step of the agent. + Args: + chat_history: The chat history to provide to the agent. + Returns: + A tuple of an agent step and a list of new messages. + """ + input_messages = self.memory.get_messages() + chat_history + llm_resp = await self.run_llm(messages=input_messages) + return await self._schema_format(llm_resp, chat_history) + + + async def _step_stream(self, chat_history: List[Message]) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: + """Run a step of the agent in streaming mode. + Args: + chat_history: The chat history to provide to the agent. + Returns: + An async iterator that yields a tuple of an agent step and a list ofnew messages. + """ + input_messages = self.memory.get_messages() + chat_history + async for llm_resp in self.run_llm_stream(messages=input_messages): + yield await self._schema_format(llm_resp, chat_history) + + + async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: + """Run the agent with the given prompt and files in streaming mode. + Args: + prompt: The prompt for the agent to run. + files: A list of files for the agent to use. If `None`, use an empty + list. + Returns: + If `stream` is `False`, an agent response object. If `stream` is + `True`, an async iterator that yields agent steps one by one. + """ + chat_history: List[Message] = [] + steps_taken: List[AgentStep] = [] + + run_input = await HumanMessage.create_with_files( + prompt, files or [], include_file_urls=self.file_needs_url + ) + + num_steps_taken = 0 + chat_history.append(run_input) + for tool in self._first_tools: + curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool) + if not isinstance(curr_step, EndStep): + chat_history.extend(new_messages) + num_steps_taken += 1 + steps_taken.append(curr_step) + else: + # If tool choice not work, skip this round + _logger.warning(f"Selected tool [{tool.tool_name}] not work") + + is_finished = False + curr_step = None + new_messages = [] + end_step_msgs = [] + while is_finished is False: + # IMPORTANT~! We use following code to get the response from LLM + # When finish_reason is fuction_call, run_llm_stream return all info in one step, but + # When finish_reason is normal chat, run_llm_stream return info in multiple steps. + async for curr_step, new_messages in self._step_stream(chat_history): + if isinstance(curr_step, ToolStep): + steps_taken.append(curr_step) + yield curr_step, new_messages + + elif isinstance(curr_step, PluginStep): + steps_taken.append(curr_step) + # 预留 调用了Plugin之后不结束的接口 + + # 此处为调用了Plugin之后直接结束的Plugin + curr_step = DEFAULT_FINISH_STEP + yield curr_step, new_messages + + if isinstance(curr_step, EndStep): + is_finished = True + end_step_msgs.extend(new_messages) + yield curr_step, new_messages + chat_history.extend(new_messages) + + self.memory.add_message(run_input) + end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs])) + self.memory.add_message(end_step_msg) + + + async def _schema_format(self, llm_resp, chat_history): + """Convert the LLM response to the agent response schema. + Args: + llm_resp: The LLM response to convert. + chat_history: The chat history to provide to the agent. + Returns: + A tuple of an agent step and a list of new messages. + """ + new_messages: List[Message] = [] output_message = llm_resp.message # AIMessage new_messages.append(output_message) if output_message.function_call is not None: diff --git a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py new file mode 100644 index 00000000..5bfd5f44 --- /dev/null +++ b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py @@ -0,0 +1,97 @@ +import json + +import pytest + +from erniebot_agent.agents import FunctionAgent +from erniebot_agent.chat_models import ERNIEBot +from erniebot_agent.memory import WholeMemory +from erniebot_agent.memory.messages import AIMessage, AIMessageChunk, FunctionMessage, HumanMessage +from erniebot_agent.tools.calculator_tool import CalculatorTool + +ONE_HIT_PROMPT = "1+4等于几?" +NO_HIT_PROMPT = "深圳今天天气怎么样?" + +# cd ERNIE-SDK/erniebot-agent/tests +# EB_AGENT_ACCESS_TOKEN= pytest integration_tests/agents/test_functional_agent_stream.py -s + +@pytest.fixture(scope="module") +def llm(): + return ERNIEBot(model="ernie-3.5") + + +@pytest.fixture(scope="module") +def tool(): + return CalculatorTool() + + +@pytest.fixture(scope="function") +def memory(): + return WholeMemory() + + +@pytest.mark.asyncio +async def test_function_agent_run_one_hit(llm, tool, memory): + agent = FunctionAgent(llm=llm, tools=[tool], memory=memory) + prompt = ONE_HIT_PROMPT + + run_logs = [] + async for step, msgs in agent.run_stream(prompt): + run_logs.append((step, msgs)) + + assert len(agent.memory.get_messages()) == 2 + assert isinstance(agent.memory.get_messages()[0], HumanMessage) + assert agent.memory.get_messages()[0].content == prompt + assert isinstance(run_logs[0][1][0], AIMessageChunk) + assert run_logs[0][1][0].function_call is not None + assert run_logs[0][1][0].function_call["name"] == tool.tool_name + assert isinstance(run_logs[0][1][1], FunctionMessage) + assert run_logs[0][1][1].name == run_logs[0][1][0].function_call["name"] + assert json.loads(run_logs[0][1][1].content) == {"formula_result": 5} + assert isinstance(agent.memory.get_messages()[1], AIMessage) + + steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] + assert len(steps) == 1 + assert steps[0].info["tool_name"] == tool.tool_name + +@pytest.mark.asyncio +async def test_function_agent_run_no_hit(llm, tool, memory): + agent = FunctionAgent(llm=llm, tools=[tool], memory=memory) + prompt = NO_HIT_PROMPT + + run_logs = [] + async for step, msgs in agent.run_stream(prompt): + run_logs.append((step, msgs)) + + assert len(agent.memory.get_messages()) == 2 + assert isinstance(agent.memory.get_messages()[0], HumanMessage) + assert agent.memory.get_messages()[0].content == prompt + assert isinstance(run_logs[0][1][0], AIMessage) + end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]) + assert end_step_msg == agent.memory.get_messages()[1].content + + steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] + assert len(steps) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("prompt", [ONE_HIT_PROMPT, NO_HIT_PROMPT]) +async def test_function_agent_run_no_tool(llm, memory, prompt): + agent = FunctionAgent(llm=llm, tools=[], memory=memory) + + run_logs = [] + async for step, msgs in agent.run_stream(prompt): + run_logs.append((step, msgs)) + + assert len(agent.memory.get_messages()) == 2 + assert isinstance(agent.memory.get_messages()[0], HumanMessage) + assert agent.memory.get_messages()[0].content == prompt + assert isinstance(run_logs[0][1][0], AIMessage) + end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]) + assert end_step_msg == agent.memory.get_messages()[1].content + + steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] + assert len(steps) == 0 + + + + From 3fb74a16f8d1d52b885bea4ec521d87ab385edd4 Mon Sep 17 00:00:00 2001 From: xiabo Date: Fri, 26 Apr 2024 20:14:05 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E4=BD=BF=E7=94=A8make=20format-check?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E4=BB=A3=E7=A0=81=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/erniebot_agent/agents/agent.py | 14 +++++++------- .../erniebot_agent/agents/function_agent.py | 15 +++++++-------- .../agents/test_functional_agent_stream.py | 18 ++++++++++-------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index 527c7278..0b3af77c 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -133,7 +133,9 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen return agent_resp @final - async def run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]: + async def run_stream( + self, prompt: str, files: Optional[Sequence[File]] = None + ) -> AsyncIterator[AgentStep]: """Run the agent asynchronously. Args: @@ -180,7 +182,6 @@ async def run_llm( else: await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp) return llm_resp - @final async def run_llm_stream( @@ -209,7 +210,6 @@ async def run_llm_stream( else: await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp) return - @final async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse: @@ -277,9 +277,11 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age raise NotImplementedError @abc.abstractmethod - async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]: + async def _run_stream( + self, prompt: str, files: Optional[Sequence[File]] = None + ) -> AsyncIterator[AgentStep]: raise NotImplementedError - + async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: for reserved_opt in ("stream", "system", "plugins"): if reserved_opt in opts: @@ -299,7 +301,6 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: opts["plugins"] = self._plugins llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts) return LLMResponse(message=llm_ret) - async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]: for reserved_opt in ("stream", "system", "plugins"): @@ -322,7 +323,6 @@ async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIt async for msg in llm_ret: yield LLMResponse(message=msg) - async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse: parsed_tool_args = self._parse_tool_args(tool_args) file_manager = self.get_file_manager() diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index 378e33c8..5f231478 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -166,7 +166,6 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age num_steps_taken += 1 response = self._create_stopped_response(chat_history, steps_taken) return response - async def _first_tool_step( self, chat_history: List[Message], selected_tool: BaseTool = None @@ -184,7 +183,6 @@ async def _first_tool_step( ) return await self._schema_format(llm_resp, chat_history) - async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]: """Run a step of the agent. Args: @@ -196,8 +194,9 @@ async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Mess llm_resp = await self.run_llm(messages=input_messages) return await self._schema_format(llm_resp, chat_history) - - async def _step_stream(self, chat_history: List[Message]) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: + async def _step_stream( + self, chat_history: List[Message] + ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: """Run a step of the agent in streaming mode. Args: chat_history: The chat history to provide to the agent. @@ -208,8 +207,9 @@ async def _step_stream(self, chat_history: List[Message]) -> AsyncIterator[Tuple async for llm_resp in self.run_llm_stream(messages=input_messages): yield await self._schema_format(llm_resp, chat_history) - - async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: + async def _run_stream( + self, prompt: str, files: Optional[Sequence[File]] = None + ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: """Run the agent with the given prompt and files in streaming mode. Args: prompt: The prompt for the agent to run. @@ -259,7 +259,7 @@ async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) # 此处为调用了Plugin之后直接结束的Plugin curr_step = DEFAULT_FINISH_STEP yield curr_step, new_messages - + if isinstance(curr_step, EndStep): is_finished = True end_step_msgs.extend(new_messages) @@ -270,7 +270,6 @@ async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs])) self.memory.add_message(end_step_msg) - async def _schema_format(self, llm_resp, chat_history): """Convert the LLM response to the agent response schema. Args: diff --git a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py index 5bfd5f44..1553a49c 100644 --- a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py +++ b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py @@ -14,6 +14,7 @@ # cd ERNIE-SDK/erniebot-agent/tests # EB_AGENT_ACCESS_TOKEN= pytest integration_tests/agents/test_functional_agent_stream.py -s + @pytest.fixture(scope="module") def llm(): return ERNIEBot(model="ernie-3.5") @@ -36,7 +37,7 @@ async def test_function_agent_run_one_hit(llm, tool, memory): run_logs = [] async for step, msgs in agent.run_stream(prompt): - run_logs.append((step, msgs)) + run_logs.append((step, msgs)) assert len(agent.memory.get_messages()) == 2 assert isinstance(agent.memory.get_messages()[0], HumanMessage) @@ -48,11 +49,12 @@ async def test_function_agent_run_one_hit(llm, tool, memory): assert run_logs[0][1][1].name == run_logs[0][1][0].function_call["name"] assert json.loads(run_logs[0][1][1].content) == {"formula_result": 5} assert isinstance(agent.memory.get_messages()[1], AIMessage) - + steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] assert len(steps) == 1 assert steps[0].info["tool_name"] == tool.tool_name + @pytest.mark.asyncio async def test_function_agent_run_no_hit(llm, tool, memory): agent = FunctionAgent(llm=llm, tools=[tool], memory=memory) @@ -66,7 +68,9 @@ async def test_function_agent_run_no_hit(llm, tool, memory): assert isinstance(agent.memory.get_messages()[0], HumanMessage) assert agent.memory.get_messages()[0].content == prompt assert isinstance(run_logs[0][1][0], AIMessage) - end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]) + end_step_msg = "".join( + [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"] + ) assert end_step_msg == agent.memory.get_messages()[1].content steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] @@ -86,12 +90,10 @@ async def test_function_agent_run_no_tool(llm, memory, prompt): assert isinstance(agent.memory.get_messages()[0], HumanMessage) assert agent.memory.get_messages()[0].content == prompt assert isinstance(run_logs[0][1][0], AIMessage) - end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]) + end_step_msg = "".join( + [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"] + ) assert end_step_msg == agent.memory.get_messages()[1].content steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] assert len(steps) == 0 - - - - From cc8b31cefead643f1e3814c74f43a7b08ec73418 Mon Sep 17 00:00:00 2001 From: xiabo Date: Sat, 27 Apr 2024 10:55:38 +0800 Subject: [PATCH 3/8] =?UTF-8?q?=E4=BD=BF=E7=94=A8make=20format-check?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E4=BB=A3=E7=A0=81=E3=80=81=E4=BD=BF=E7=94=A8?= =?UTF-8?q?make=20format=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/erniebot_agent/agents/agent.py | 9 +++++++-- .../erniebot_agent/agents/function_agent.py | 18 ++++++++++++++++-- .../agents/test_functional_agent_stream.py | 7 ++++++- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index 0b3af77c..d44f163a 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -2,8 +2,8 @@ import json import logging from typing import ( - AsyncIterator, Any, + AsyncIterator, Dict, Final, Iterable, @@ -21,7 +21,12 @@ from erniebot_agent.agents.callback.default import get_default_callbacks from erniebot_agent.agents.callback.handlers.base import CallbackHandler from erniebot_agent.agents.mixins import GradioMixin -from erniebot_agent.agents.schema import AgentResponse, LLMResponse, ToolResponse, AgentStep +from erniebot_agent.agents.schema import ( + AgentResponse, + AgentStep, + LLMResponse, + ToolResponse, +) from erniebot_agent.chat_models.erniebot import BaseERNIEBot from erniebot_agent.file import ( File, diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index 5f231478..867e85fd 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -13,7 +13,16 @@ # limitations under the License. import logging -from typing import Final, Iterable, List, Optional, Sequence, Tuple, Union, AsyncIterator +from typing import ( + AsyncIterator, + Final, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, +) from erniebot_agent.agents.agent import Agent from erniebot_agent.agents.callback.callback_manager import CallbackManager @@ -31,7 +40,12 @@ from erniebot_agent.chat_models.erniebot import BaseERNIEBot from erniebot_agent.file import File, FileManager from erniebot_agent.memory import Memory -from erniebot_agent.memory.messages import FunctionMessage, HumanMessage, Message, AIMessage +from erniebot_agent.memory.messages import ( + AIMessage, + FunctionMessage, + HumanMessage, + Message, +) from erniebot_agent.tools.base import BaseTool from erniebot_agent.tools.tool_manager import ToolManager diff --git a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py index 1553a49c..6c035b75 100644 --- a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py +++ b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py @@ -5,7 +5,12 @@ from erniebot_agent.agents import FunctionAgent from erniebot_agent.chat_models import ERNIEBot from erniebot_agent.memory import WholeMemory -from erniebot_agent.memory.messages import AIMessage, AIMessageChunk, FunctionMessage, HumanMessage +from erniebot_agent.memory.messages import ( + AIMessage, + AIMessageChunk, + FunctionMessage, + HumanMessage, +) from erniebot_agent.tools.calculator_tool import CalculatorTool ONE_HIT_PROMPT = "1+4等于几?" From 72f97f6fc7dfcefc5285526b1721aa43c98d1daa Mon Sep 17 00:00:00 2001 From: xiabo Date: Sun, 28 Apr 2024 15:22:57 +0800 Subject: [PATCH 4/8] =?UTF-8?q?=E4=BD=BF=E7=94=A8`python=20-m=20mypy=20src?= =?UTF-8?q?`=E6=A3=80=E6=9F=A5=E5=92=8C=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/erniebot_agent/agents/agent.py | 30 ++++++++++++++----- .../erniebot_agent/agents/function_agent.py | 9 +++--- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index d44f163a..c4428d1f 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -22,6 +22,7 @@ from erniebot_agent.agents.callback.handlers.base import CallbackHandler from erniebot_agent.agents.mixins import GradioMixin from erniebot_agent.agents.schema import ( + DEFAULT_FINISH_STEP, AgentResponse, AgentStep, LLMResponse, @@ -140,7 +141,7 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen @final async def run_stream( self, prompt: str, files: Optional[Sequence[File]] = None - ) -> AsyncIterator[AgentStep]: + ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: """Run the agent asynchronously. Args: @@ -154,14 +155,21 @@ async def run_stream( await self._ensure_managed_files(files) await self._callback_manager.on_run_start(agent=self, prompt=prompt) try: - async for step in self._run_stream(prompt, files): - yield step + async for step, msg in self._run_stream(prompt, files): + yield (step, msg) except BaseException as e: await self._callback_manager.on_run_error(agent=self, error=e) raise e else: - await self._callback_manager.on_run_end(agent=self, response=None) - return + await self._callback_manager.on_run_end( + agent=self, + response=AgentResponse( + text="Agent run stopped early.", + chat_history=self.memory.get_messages(), + steps=[step], + status="STOPPED", + ), + ) @final async def run_llm( @@ -284,8 +292,16 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age @abc.abstractmethod async def _run_stream( self, prompt: str, files: Optional[Sequence[File]] = None - ) -> AsyncIterator[AgentStep]: - raise NotImplementedError + ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: + """ + Abstract asynchronous generator method that should be implemented by subclasses. + This method should yield a sequence of (AgentStep, List[Message]) tuples based on the given + prompt and optionally accompanying files. + """ + if False: + # This conditional block is strictly for static type-checking purposes (e.g., mypy) and will not be executed. + only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, []) + yield only_for_mypy_type_check async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: for reserved_opt in ("stream", "system", "plugins"): diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index 867e85fd..fd4bde67 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -182,7 +182,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age return response async def _first_tool_step( - self, chat_history: List[Message], selected_tool: BaseTool = None + self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None ) -> Tuple[AgentStep, List[Message]]: input_messages = self.memory.get_messages() + chat_history if selected_tool is None: @@ -254,7 +254,6 @@ async def _run_stream( _logger.warning(f"Selected tool [{tool.tool_name}] not work") is_finished = False - curr_step = None new_messages = [] end_step_msgs = [] while is_finished is False: @@ -274,17 +273,19 @@ async def _run_stream( curr_step = DEFAULT_FINISH_STEP yield curr_step, new_messages - if isinstance(curr_step, EndStep): + elif isinstance(curr_step, EndStep): is_finished = True end_step_msgs.extend(new_messages) yield curr_step, new_messages + else: + raise RuntimeError("Invalid step type") chat_history.extend(new_messages) self.memory.add_message(run_input) end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs])) self.memory.add_message(end_step_msg) - async def _schema_format(self, llm_resp, chat_history): + async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]: """Convert the LLM response to the agent response schema. Args: llm_resp: The LLM response to convert. From 6ad47979c86b9ad61290830adb37e57f20ce0464 Mon Sep 17 00:00:00 2001 From: xiabo Date: Mon, 29 Apr 2024 10:00:36 +0800 Subject: [PATCH 5/8] =?UTF-8?q?=E4=BD=BF=E7=94=A8`make=20lint`=E6=A3=80?= =?UTF-8?q?=E6=9F=A5=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- erniebot-agent/src/erniebot_agent/agents/agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index c4428d1f..643ac53d 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -299,7 +299,8 @@ async def _run_stream( prompt and optionally accompanying files. """ if False: - # This conditional block is strictly for static type-checking purposes (e.g., mypy) and will not be executed. + # This conditional block is strictly for static type-checking purposes (e.g., mypy) + # and will not be executed. only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, []) yield only_for_mypy_type_check From 2d3c851c271fd93391333ba22c02cc27d3a90f3e Mon Sep 17 00:00:00 2001 From: xiabo Date: Tue, 30 Apr 2024 11:17:14 +0800 Subject: [PATCH 6/8] =?UTF-8?q?=E6=A0=B9=E6=8D=AEreview=E6=84=8F=E8=A7=81?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/erniebot_agent/agents/agent.py | 30 +++++++++++++++---- .../erniebot_agent/agents/function_agent.py | 20 +++++++------ .../agents/test_functional_agent_stream.py | 28 +++++++---------- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index 643ac53d..33abf6ad 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -142,14 +142,14 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen async def run_stream( self, prompt: str, files: Optional[Sequence[File]] = None ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]: - """Run the agent asynchronously. + """Run the agent asynchronously, returning an async iterator of responses. Args: prompt: A natural language text describing the task that the agent should perform. files: A list of files that the agent can use to perform the task. Returns: - Response AgentStep from the agent. + Iterator of responses from the agent. """ if files: await self._ensure_managed_files(files) @@ -164,7 +164,7 @@ async def run_stream( await self._callback_manager.on_run_end( agent=self, response=AgentResponse( - text="Agent run stopped early.", + text="Agent run stopped.", chat_history=self.memory.get_messages(), steps=[step], status="STOPPED", @@ -177,7 +177,7 @@ async def run_llm( messages: List[Message], **llm_opts: Any, ) -> LLMResponse: - """Run the LLM asynchronously. + """Run the LLM asynchronously, returning final response. Args: messages: The input messages. @@ -202,14 +202,14 @@ async def run_llm_stream( messages: List[Message], **llm_opts: Any, ) -> AsyncIterator[LLMResponse]: - """Run the LLM asynchronously. + """Run the LLM asynchronously, returning an async iterator of responses Args: messages: The input messages. llm_opts: Options to pass to the LLM. Returns: - Response from the LLM. + Iterator of responses from the LLM. """ llm_resp = None await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages) @@ -305,6 +305,15 @@ async def _run_stream( yield only_for_mypy_type_check async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: + """Run the LLM with the given messages and options. + + Args: + messages: The input messages. + opts: Options to pass to the LLM. + + Returns: + Response from the LLM. + """ for reserved_opt in ("stream", "system", "plugins"): if reserved_opt in opts: raise TypeError(f"`{reserved_opt}` should not be set.") @@ -325,6 +334,15 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: return LLMResponse(message=llm_ret) async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]: + """Run the LLM, yielding an async iterator of responses. + + Args: + messages: The input messages. + opts: Options to pass to the LLM. + + Returns: + Async iterator of responses from the LLM. + """ for reserved_opt in ("stream", "system", "plugins"): if reserved_opt in opts: raise TypeError(f"`{reserved_opt}` should not be set.") diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index fd4bde67..fc0badab 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -150,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age chat_history.append(run_input) for tool in self._first_tools: - curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool) + curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool) if not isinstance(curr_step, EndStep): chat_history.extend(new_messages) num_steps_taken += 1 @@ -181,13 +181,13 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age response = self._create_stopped_response(chat_history, steps_taken) return response - async def _first_tool_step( + async def _call_first_tools( self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None ) -> Tuple[AgentStep, List[Message]]: input_messages = self.memory.get_messages() + chat_history if selected_tool is None: llm_resp = await self.run_llm(messages=input_messages) - return await self._schema_format(llm_resp, chat_history) + return await self._process_step(llm_resp, chat_history) tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}} llm_resp = await self.run_llm( @@ -195,7 +195,7 @@ async def _first_tool_step( functions=[selected_tool.function_call_schema()], # only regist one tool tool_choice=tool_choice, ) - return await self._schema_format(llm_resp, chat_history) + return await self._process_step(llm_resp, chat_history) async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]: """Run a step of the agent. @@ -206,7 +206,7 @@ async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Mess """ input_messages = self.memory.get_messages() + chat_history llm_resp = await self.run_llm(messages=input_messages) - return await self._schema_format(llm_resp, chat_history) + return await self._process_step(llm_resp, chat_history) async def _step_stream( self, chat_history: List[Message] @@ -219,7 +219,7 @@ async def _step_stream( """ input_messages = self.memory.get_messages() + chat_history async for llm_resp in self.run_llm_stream(messages=input_messages): - yield await self._schema_format(llm_resp, chat_history) + yield await self._process_step(llm_resp, chat_history) async def _run_stream( self, prompt: str, files: Optional[Sequence[File]] = None @@ -244,7 +244,7 @@ async def _run_stream( chat_history.append(run_input) for tool in self._first_tools: - curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool) + curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool) if not isinstance(curr_step, EndStep): chat_history.extend(new_messages) num_steps_taken += 1 @@ -285,8 +285,8 @@ async def _run_stream( end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs])) self.memory.add_message(end_step_msg) - async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]: - """Convert the LLM response to the agent response schema. + async def _process_step(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]: + """Process and execute a step of the agent from LLM response. Args: llm_resp: The LLM response to convert. chat_history: The chat history to provide to the agent. @@ -296,6 +296,7 @@ async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[ new_messages: List[Message] = [] output_message = llm_resp.message # AIMessage new_messages.append(output_message) + # handle function call if output_message.function_call is not None: tool_name = output_message.function_call["name"] tool_args = output_message.function_call["arguments"] @@ -310,6 +311,7 @@ async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[ ), new_messages, ) + # handle plugin info with input/output files elif output_message.plugin_info is not None: file_manager = self.get_file_manager() return ( diff --git a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py index 6c035b75..46ec4463 100644 --- a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py +++ b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py @@ -3,6 +3,7 @@ import pytest from erniebot_agent.agents import FunctionAgent +from erniebot_agent.agents.schema import EndStep from erniebot_agent.chat_models import ERNIEBot from erniebot_agent.memory import WholeMemory from erniebot_agent.memory.messages import ( @@ -16,9 +17,6 @@ ONE_HIT_PROMPT = "1+4等于几?" NO_HIT_PROMPT = "深圳今天天气怎么样?" -# cd ERNIE-SDK/erniebot-agent/tests -# EB_AGENT_ACCESS_TOKEN= pytest integration_tests/agents/test_functional_agent_stream.py -s - @pytest.fixture(scope="module") def llm(): @@ -55,9 +53,9 @@ async def test_function_agent_run_one_hit(llm, tool, memory): assert json.loads(run_logs[0][1][1].content) == {"formula_result": 5} assert isinstance(agent.memory.get_messages()[1], AIMessage) - steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] - assert len(steps) == 1 - assert steps[0].info["tool_name"] == tool.tool_name + tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)] + assert len(tool_steps) == 1 + assert tool_steps[0].info["tool_name"] == tool.tool_name @pytest.mark.asyncio @@ -73,14 +71,12 @@ async def test_function_agent_run_no_hit(llm, tool, memory): assert isinstance(agent.memory.get_messages()[0], HumanMessage) assert agent.memory.get_messages()[0].content == prompt assert isinstance(run_logs[0][1][0], AIMessage) - end_step_msg = "".join( - [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"] - ) + end_step_msg = "".join([msg[0].content for step, msg in run_logs if isinstance(step, EndStep)]) assert end_step_msg == agent.memory.get_messages()[1].content - steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] - assert len(steps) == 0 - + tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)] + assert len(tool_steps) == 0 + @pytest.mark.asyncio @pytest.mark.parametrize("prompt", [ONE_HIT_PROMPT, NO_HIT_PROMPT]) @@ -95,10 +91,8 @@ async def test_function_agent_run_no_tool(llm, memory, prompt): assert isinstance(agent.memory.get_messages()[0], HumanMessage) assert agent.memory.get_messages()[0].content == prompt assert isinstance(run_logs[0][1][0], AIMessage) - end_step_msg = "".join( - [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"] - ) + end_step_msg = "".join([msg[0].content for step, msg in run_logs if isinstance(step, EndStep)]) assert end_step_msg == agent.memory.get_messages()[1].content - steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] - assert len(steps) == 0 + tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)] + assert len(tool_steps) == 0 \ No newline at end of file From c95bde738451410484ffbb6ca734cd041aefeeba Mon Sep 17 00:00:00 2001 From: xiabo Date: Wed, 1 May 2024 20:19:04 +0800 Subject: [PATCH 7/8] =?UTF-8?q?=E4=BD=BF=E7=94=A8`python=20-m=20black=20--?= =?UTF-8?q?check`=E6=A3=80=E6=9F=A5=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../integration_tests/agents/test_functional_agent_stream.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py index 46ec4463..7d9d31a8 100644 --- a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py +++ b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py @@ -76,7 +76,7 @@ async def test_function_agent_run_no_hit(llm, tool, memory): tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)] assert len(tool_steps) == 0 - + @pytest.mark.asyncio @pytest.mark.parametrize("prompt", [ONE_HIT_PROMPT, NO_HIT_PROMPT]) @@ -95,4 +95,4 @@ async def test_function_agent_run_no_tool(llm, memory, prompt): assert end_step_msg == agent.memory.get_messages()[1].content tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)] - assert len(tool_steps) == 0 \ No newline at end of file + assert len(tool_steps) == 0 From 8c520b239724de2ea6c13e4e38bba499016410ac Mon Sep 17 00:00:00 2001 From: xiabo Date: Wed, 1 May 2024 20:54:04 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E5=88=A4=E6=96=ADtyping.TYPE=5FCHECKING?= =?UTF-8?q?=E6=9D=A5=E5=A4=84=E7=90=86=E7=BC=96=E8=AF=91=E5=99=A8=E6=A3=80?= =?UTF-8?q?=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- erniebot-agent/src/erniebot_agent/agents/agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index 33abf6ad..d04a9c96 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -2,6 +2,7 @@ import json import logging from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Dict, @@ -298,7 +299,8 @@ async def _run_stream( This method should yield a sequence of (AgentStep, List[Message]) tuples based on the given prompt and optionally accompanying files. """ - if False: + if TYPE_CHECKING: + # HACK # This conditional block is strictly for static type-checking purposes (e.g., mypy) # and will not be executed. only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, [])