From 2d3c851c271fd93391333ba22c02cc27d3a90f3e Mon Sep 17 00:00:00 2001
From: xiabo
Date: Tue, 30 Apr 2024 11:17:14 +0800
Subject: [PATCH] Revise according to review comments
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../src/erniebot_agent/agents/agent.py        | 30 +++++++++++++++----
 .../erniebot_agent/agents/function_agent.py   | 20 +++++++------
 .../agents/test_functional_agent_stream.py    | 28 +++++++----------
 3 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py
index 643ac53d..33abf6ad 100644
--- a/erniebot-agent/src/erniebot_agent/agents/agent.py
+++ b/erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -142,14 +142,14 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
     async def run_stream(
         self, prompt: str, files: Optional[Sequence[File]] = None
     ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
-        """Run the agent asynchronously.
+        """Run the agent asynchronously, returning an async iterator of responses.

         Args:
             prompt: A natural language text describing the task that the agent should perform.
             files: A list of files that the agent can use to perform the task.

         Returns:
-            Response AgentStep from the agent.
+            Iterator of responses from the agent.
         """
         if files:
             await self._ensure_managed_files(files)
@@ -164,7 +164,7 @@ async def run_stream(
             await self._callback_manager.on_run_end(
                 agent=self,
                 response=AgentResponse(
-                    text="Agent run stopped early.",
+                    text="Agent run stopped.",
                     chat_history=self.memory.get_messages(),
                     steps=[step],
                     status="STOPPED",
@@ -177,7 +177,7 @@ async def run_llm(
         self,
         messages: List[Message],
         **llm_opts: Any,
     ) -> LLMResponse:
-        """Run the LLM asynchronously.
+        """Run the LLM asynchronously, returning the final response.

         Args:
             messages: The input messages.
@@ -202,14 +202,14 @@ async def run_llm_stream(
         messages: List[Message],
         **llm_opts: Any,
     ) -> AsyncIterator[LLMResponse]:
-        """Run the LLM asynchronously.
+        """Run the LLM asynchronously, returning an async iterator of responses.

         Args:
             messages: The input messages.
             llm_opts: Options to pass to the LLM.

         Returns:
-            Response from the LLM.
+            Iterator of responses from the LLM.
         """
         llm_resp = None
         await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
@@ -305,6 +305,15 @@ async def _run_stream(
             yield only_for_mypy_type_check

     async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
+        """Run the LLM with the given messages and options.
+
+        Args:
+            messages: The input messages.
+            opts: Options to pass to the LLM.
+
+        Returns:
+            Response from the LLM.
+        """
         for reserved_opt in ("stream", "system", "plugins"):
             if reserved_opt in opts:
                 raise TypeError(f"`{reserved_opt}` should not be set.")
@@ -325,6 +334,15 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
         return LLMResponse(message=llm_ret)

     async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
+        """Run the LLM, yielding an async iterator of responses.
+
+        Args:
+            messages: The input messages.
+            opts: Options to pass to the LLM.
+
+        Returns:
+            Async iterator of responses from the LLM.
+ """ for reserved_opt in ("stream", "system", "plugins"): if reserved_opt in opts: raise TypeError(f"`{reserved_opt}` should not be set.") diff --git a/erniebot-agent/src/erniebot_agent/agents/function_agent.py b/erniebot-agent/src/erniebot_agent/agents/function_agent.py index fd4bde67..fc0badab 100644 --- a/erniebot-agent/src/erniebot_agent/agents/function_agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/function_agent.py @@ -150,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age chat_history.append(run_input) for tool in self._first_tools: - curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool) + curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool) if not isinstance(curr_step, EndStep): chat_history.extend(new_messages) num_steps_taken += 1 @@ -181,13 +181,13 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age response = self._create_stopped_response(chat_history, steps_taken) return response - async def _first_tool_step( + async def _call_first_tools( self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None ) -> Tuple[AgentStep, List[Message]]: input_messages = self.memory.get_messages() + chat_history if selected_tool is None: llm_resp = await self.run_llm(messages=input_messages) - return await self._schema_format(llm_resp, chat_history) + return await self._process_step(llm_resp, chat_history) tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}} llm_resp = await self.run_llm( @@ -195,7 +195,7 @@ async def _first_tool_step( functions=[selected_tool.function_call_schema()], # only regist one tool tool_choice=tool_choice, ) - return await self._schema_format(llm_resp, chat_history) + return await self._process_step(llm_resp, chat_history) async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]: """Run a step of the agent. @@ -206,7 +206,7 @@ async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Mess """ input_messages = self.memory.get_messages() + chat_history llm_resp = await self.run_llm(messages=input_messages) - return await self._schema_format(llm_resp, chat_history) + return await self._process_step(llm_resp, chat_history) async def _step_stream( self, chat_history: List[Message] @@ -219,7 +219,7 @@ async def _step_stream( """ input_messages = self.memory.get_messages() + chat_history async for llm_resp in self.run_llm_stream(messages=input_messages): - yield await self._schema_format(llm_resp, chat_history) + yield await self._process_step(llm_resp, chat_history) async def _run_stream( self, prompt: str, files: Optional[Sequence[File]] = None @@ -244,7 +244,7 @@ async def _run_stream( chat_history.append(run_input) for tool in self._first_tools: - curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool) + curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool) if not isinstance(curr_step, EndStep): chat_history.extend(new_messages) num_steps_taken += 1 @@ -285,8 +285,8 @@ async def _run_stream( end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs])) self.memory.add_message(end_step_msg) - async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]: - """Convert the LLM response to the agent response schema. 
+    async def _process_step(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
+        """Process and execute a step of the agent from the LLM response.

         Args:
             llm_resp: The LLM response to convert.
             chat_history: The chat history to provide to the agent.
@@ -296,6 +296,7 @@ async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[
         new_messages: List[Message] = []
         output_message = llm_resp.message  # AIMessage
         new_messages.append(output_message)
+        # handle function call
         if output_message.function_call is not None:
             tool_name = output_message.function_call["name"]
             tool_args = output_message.function_call["arguments"]
@@ -310,6 +311,7 @@ async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[
                 ),
                 new_messages,
             )
+        # handle plugin info with input/output files
         elif output_message.plugin_info is not None:
             file_manager = self.get_file_manager()
             return (
diff --git a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py
index 6c035b75..46ec4463 100644
--- a/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py
+++ b/erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py
@@ -3,6 +3,7 @@
 import pytest

 from erniebot_agent.agents import FunctionAgent
+from erniebot_agent.agents.schema import EndStep
 from erniebot_agent.chat_models import ERNIEBot
 from erniebot_agent.memory import WholeMemory
 from erniebot_agent.memory.messages import (
@@ -16,9 +17,6 @@
 ONE_HIT_PROMPT = "1+4等于几?"
 NO_HIT_PROMPT = "深圳今天天气怎么样?"

-# cd ERNIE-SDK/erniebot-agent/tests
-# EB_AGENT_ACCESS_TOKEN= pytest integration_tests/agents/test_functional_agent_stream.py -s
-

 @pytest.fixture(scope="module")
 def llm():
@@ -55,9 +53,9 @@ async def test_function_agent_run_one_hit(llm, tool, memory):
     assert json.loads(run_logs[0][1][1].content) == {"formula_result": 5}
     assert isinstance(agent.memory.get_messages()[1], AIMessage)

-    steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
-    assert len(steps) == 1
-    assert steps[0].info["tool_name"] == tool.tool_name
+    tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)]
+    assert len(tool_steps) == 1
+    assert tool_steps[0].info["tool_name"] == tool.tool_name


 @pytest.mark.asyncio
@@ -73,14 +71,12 @@ async def test_function_agent_run_no_hit(llm, tool, memory):
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
     assert agent.memory.get_messages()[0].content == prompt
     assert isinstance(run_logs[0][1][0], AIMessage)
-    end_step_msg = "".join(
-        [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
-    )
+    end_step_msg = "".join([msg[0].content for step, msg in run_logs if isinstance(step, EndStep)])
     assert end_step_msg == agent.memory.get_messages()[1].content

-    steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
-    assert len(steps) == 0
-
+    tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)]
+    assert len(tool_steps) == 0
+

 @pytest.mark.asyncio
 @pytest.mark.parametrize("prompt", [ONE_HIT_PROMPT, NO_HIT_PROMPT])
@@ -95,10 +91,8 @@ async def test_function_agent_run_no_tool(llm, memory, prompt):
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
     assert agent.memory.get_messages()[0].content == prompt
     assert isinstance(run_logs[0][1][0], AIMessage)
-    end_step_msg = "".join(
-        [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
"EndStep"] - ) + end_step_msg = "".join([msg[0].content for step, msg in run_logs if isinstance(step, EndStep)]) assert end_step_msg == agent.memory.get_messages()[1].content - steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"] - assert len(steps) == 0 + tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)] + assert len(tool_steps) == 0 \ No newline at end of file