diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py
index 527c7278..0b3af77c 100644
--- a/erniebot-agent/src/erniebot_agent/agents/agent.py
+++ b/erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -133,7 +133,9 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
         return agent_resp

     @final
-    async def run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]:
+    async def run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[AgentStep]:
         """Run the agent asynchronously.

         Args:
@@ -180,7 +182,6 @@ async def run_llm(
         else:
             await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
         return llm_resp
-

     @final
     async def run_llm_stream(
@@ -209,7 +210,6 @@ async def run_llm_stream(
         else:
             await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
         return
-

     @final
     async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
@@ -277,9 +277,11 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
         raise NotImplementedError

     @abc.abstractmethod
-    async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]:
+    async def _run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[AgentStep]:
         raise NotImplementedError
-
+
     async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
         for reserved_opt in ("stream", "system", "plugins"):
             if reserved_opt in opts:
@@ -299,7 +301,6 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
                 opts["plugins"] = self._plugins
         llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts)
         return LLMResponse(message=llm_ret)
-

     async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
         for reserved_opt in ("stream", "system", "plugins"):
@@ -322,7 +323,6 @@ async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIt
         async for msg in llm_ret:
             yield LLMResponse(message=msg)
-

     async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
         parsed_tool_args = self._parse_tool_args(tool_args)
         file_manager = self.get_file_manager()
diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py
index 378e33c8..5f231478 100644
--- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py
+++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -166,7 +166,6 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
             num_steps_taken += 1
         response = self._create_stopped_response(chat_history, steps_taken)
         return response
-

     async def _first_tool_step(
         self, chat_history: List[Message], selected_tool: BaseTool = None
@@ -184,7 +183,6 @@ async def _first_tool_step(
         )
         return await self._schema_format(llm_resp, chat_history)
-

     async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
         """Run a step of the agent.

         Args:
@@ -196,8 +194,9 @@ async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Mess
         llm_resp = await self.run_llm(messages=input_messages)
         return await self._schema_format(llm_resp, chat_history)
-
-    async def _step_stream(self, chat_history: List[Message]) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+
+    async def _step_stream(
+        self, chat_history: List[Message]
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
         """Run a step of the agent in streaming mode.

         Args:
             chat_history: The chat history to provide to the agent.
@@ -208,8 +207,9 @@ async def _step_stream(self, chat_history: List[Message]) -> AsyncIterator[Tuple
         async for llm_resp in self.run_llm_stream(messages=input_messages):
             yield await self._schema_format(llm_resp, chat_history)
-
-    async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+
+    async def _run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
         """Run the agent with the given prompt and files in streaming mode.

         Args:
             prompt: The prompt for the agent to run.
@@ -259,7 +259,7 @@ async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None)
                 # Here the plugin ends directly after the plugin call
                 curr_step = DEFAULT_FINISH_STEP
             yield curr_step, new_messages
-
+
             if isinstance(curr_step, EndStep):
                 is_finished = True
             end_step_msgs.extend(new_messages)
@@ -270,7 +270,6 @@ async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None)
         end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
         self.memory.add_message(end_step_msg)
-

     async def _schema_format(self, llm_resp, chat_history):
         """Convert the LLM response to the agent response schema.

         Args:
diff --git a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py
index 5bfd5f44..1553a49c 100644
--- a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py
+++ b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py
@@ -14,6 +14,7 @@
 # cd ERNIE-SDK/erniebot-agent/tests
 # EB_AGENT_ACCESS_TOKEN= pytest integration_tests/agents/test_functional_agent_stream.py -s

+
 @pytest.fixture(scope="module")
 def llm():
     return ERNIEBot(model="ernie-3.5")
@@ -36,7 +37,7 @@ async def test_function_agent_run_one_hit(llm, tool, memory):
     run_logs = []
     async for step, msgs in agent.run_stream(prompt):
-        run_logs.append((step, msgs)) 
+        run_logs.append((step, msgs))

     assert len(agent.memory.get_messages()) == 2
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
@@ -48,11 +49,12 @@ async def test_function_agent_run_one_hit(llm, tool, memory):
     assert run_logs[0][1][1].name == run_logs[0][1][0].function_call["name"]
     assert json.loads(run_logs[0][1][1].content) == {"formula_result": 5}
     assert isinstance(agent.memory.get_messages()[1], AIMessage)
-
+
     steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
     assert len(steps) == 1
     assert steps[0].info["tool_name"] == tool.tool_name

+
 @pytest.mark.asyncio
 async def test_function_agent_run_no_hit(llm, tool, memory):
     agent = FunctionAgent(llm=llm, tools=[tool], memory=memory)
@@ -66,7 +68,9 @@ async def test_function_agent_run_no_hit(llm, tool, memory):
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
     assert agent.memory.get_messages()[0].content == prompt
     assert isinstance(run_logs[0][1][0], AIMessage)
-    end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"])
+    end_step_msg = "".join(
+        [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
+    )
     assert end_step_msg == agent.memory.get_messages()[1].content

     steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
@@ -86,12 +90,10 @@ async def test_function_agent_run_no_tool(llm, memory, prompt):
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
     assert agent.memory.get_messages()[0].content == prompt
     assert isinstance(run_logs[0][1][0], AIMessage)
-    end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"])
+    end_step_msg = "".join(
+        [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
+    )
     assert end_step_msg == agent.memory.get_messages()[1].content

     steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
     assert len(steps) == 0
-
-
-
-
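For reference, a minimal sketch of how the new streaming API would be consumed, mirroring the `async for step, msgs in agent.run_stream(prompt)` pattern from the integration tests above. The import paths, the `WholeMemory` choice, and the empty `tools` list are assumptions for illustration, not part of this diff; the tests unpack `(step, messages)` pairs from the iterator, and the sketch follows them.

```python
import asyncio

from erniebot_agent.agents import FunctionAgent
from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.memory import WholeMemory


async def main():
    # Assumed wiring: the tests pass a calculator tool via the `tool` fixture;
    # an empty tool list keeps this sketch self-contained. Like the tests, it
    # needs EB_AGENT_ACCESS_TOKEN set in the environment.
    agent = FunctionAgent(llm=ERNIEBot(model="ernie-3.5"), tools=[], memory=WholeMemory())

    # Per the tests, each iteration yields a step and its new messages as soon
    # as the step completes; EndStep carries the final answer chunks.
    async for step, msgs in agent.run_stream("Add 1 plus 4"):
        print(type(step).__name__, [msg.content for msg in msgs])


if __name__ == "__main__":
    asyncio.run(main())
```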