Use make format-check for code optimization
xiabo0816 committed Apr 26, 2024
1 parent d3cfef2 commit 3fb74a1
Showing 3 changed files with 24 additions and 23 deletions.
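The hunks below are mechanical style fixes of the kind a line-length formatter flags: signatures that overflow the limit are wrapped, doubled blank lines between methods are collapsed to one, and trailing blank lines are dropped. As a rough, self-contained illustration of the wrapping pattern (a hypothetical function, not code from this repository):

# Hypothetical example of the wrapping style applied throughout this commit.
from typing import AsyncIterator, Optional, Sequence


# Before the fix, the whole signature sat on one line and exceeded the limit;
# after it, the parameters move to their own line and the return annotation
# closes the parenthesis, exactly as in the run_stream/_run_stream hunks below.
async def run_stream_example(
    prompt: str, files: Optional[Sequence[str]] = None
) -> AsyncIterator[str]:
    yield prompt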
14 changes: 7 additions & 7 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -133,7 +133,9 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
         return agent_resp

     @final
-    async def run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]:
+    async def run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[AgentStep]:
         """Run the agent asynchronously.

         Args:
@@ -180,7 +182,6 @@ async def run_llm(
         else:
             await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
         return llm_resp
-

     @final
     async def run_llm_stream(
Expand Down Expand Up @@ -209,7 +210,6 @@ async def run_llm_stream(
else:
await self._callback_manager.on_llm_end(agent=self, llm=self.llm, response=llm_resp)
return


@final
async def run_tool(self, tool_name: str, tool_args: str) -> ToolResponse:
@@ -277,9 +277,11 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
         raise NotImplementedError

     @abc.abstractmethod
-    async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[AgentStep]:
+    async def _run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[AgentStep]:
         raise NotImplementedError

     async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
         for reserved_opt in ("stream", "system", "plugins"):
             if reserved_opt in opts:
@@ -299,7 +301,6 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
             opts["plugins"] = self._plugins
         llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts)
         return LLMResponse(message=llm_ret)
-

     async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
         for reserved_opt in ("stream", "system", "plugins"):
@@ -322,7 +323,6 @@ async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIt
         async for msg in llm_ret:
             yield LLMResponse(message=msg)

-
     async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
         parsed_tool_args = self._parse_tool_args(tool_args)
         file_manager = self.get_file_manager()
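For context, the `_run_llm` and `_run_llm_stream` hunks above pass through a guard that rejects chat options the agent manages itself ("stream", "system", "plugins") before delegating to `self.llm.chat`. A minimal standalone sketch of that guard pattern, with simplified, hypothetical names rather than the erniebot-agent API:

from typing import Any

# Options the agent reserves for itself, mirroring the tuple in _run_llm.
RESERVED_OPTS = ("stream", "system", "plugins")


def reject_reserved_opts(**opts: Any) -> dict:
    # Fail early if the caller tries to override an option the agent sets.
    for reserved_opt in RESERVED_OPTS:
        if reserved_opt in opts:
            raise ValueError(f"Option {reserved_opt!r} is reserved and cannot be set by the caller")
    return opts


# Usage: reject_reserved_opts(temperature=0.1) returns the opts unchanged,
# while reject_reserved_opts(stream=True) raises ValueError.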
15 changes: 7 additions & 8 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
@@ -166,7 +166,6 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
             num_steps_taken += 1
         response = self._create_stopped_response(chat_history, steps_taken)
         return response
-

     async def _first_tool_step(
         self, chat_history: List[Message], selected_tool: BaseTool = None
@@ -184,7 +183,6 @@ async def _first_tool_step(
         )
         return await self._schema_format(llm_resp, chat_history)
-

     async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
         """Run a step of the agent.
         Args:
@@ -196,8 +194,9 @@ async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Mess
         llm_resp = await self.run_llm(messages=input_messages)
         return await self._schema_format(llm_resp, chat_history)

-
-    async def _step_stream(self, chat_history: List[Message]) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+    async def _step_stream(
+        self, chat_history: List[Message]
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
         """Run a step of the agent in streaming mode.
         Args:
             chat_history: The chat history to provide to the agent.
@@ -208,8 +207,9 @@ async def _step_stream(self, chat_history: List[Message]) -> AsyncIterator[Tuple
         async for llm_resp in self.run_llm_stream(messages=input_messages):
             yield await self._schema_format(llm_resp, chat_history)

-
-    async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
+    async def _run_stream(
+        self, prompt: str, files: Optional[Sequence[File]] = None
+    ) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
         """Run the agent with the given prompt and files in streaming mode.
         Args:
             prompt: The prompt for the agent to run.
@@ -259,7 +259,7 @@ async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None)
                 # This is a plugin that ends immediately after being called
                 curr_step = DEFAULT_FINISH_STEP
             yield curr_step, new_messages
-
+
             if isinstance(curr_step, EndStep):
                 is_finished = True
                 end_step_msgs.extend(new_messages)
@@ -270,7 +270,6 @@ async def _run_stream(self, prompt: str, files: Optional[Sequence[File]] = None)
         end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
         self.memory.add_message(end_step_msg)

-
     async def _schema_format(self, llm_resp, chat_history):
         """Convert the LLM response to the agent response schema.
         Args:
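The `_run_stream` hunks above show the agent's streaming loop: each `_step_stream` iteration yields a `(step, messages)` pair, the loop stops once an `EndStep` arrives, and the accumulated end-step content is written back to memory as a single `AIMessage`. A rough, self-contained sketch of that control flow, using stand-in classes rather than the real `AgentStep`/`EndStep`/`Memory` types:

import asyncio
from typing import AsyncIterator, List, Tuple


class EndStep:
    """Stand-in for the real EndStep type."""


async def fake_step_stream() -> AsyncIterator[Tuple[object, List[str]]]:
    # Stand-in for FunctionAgent._step_stream: one tool step, then an end step.
    yield object(), ["tool request", "tool response"]
    yield EndStep(), ["final answer"]


async def run_stream_sketch() -> str:
    end_step_msgs: List[str] = []
    async for step, msgs in fake_step_stream():
        if isinstance(step, EndStep):
            end_step_msgs.extend(msgs)
            break  # the real loop sets is_finished and falls out
    # The real code joins the contents into one AIMessage for memory.
    return "".join(end_step_msgs)


print(asyncio.run(run_stream_sketch()))  # -> "final answer"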
18 changes: 10 additions & 8 deletions erniebot-agent/tests/integration_tests/agents/test_functional_agent_stream.py
@@ -14,6 +14,7 @@
 # cd ERNIE-SDK/erniebot-agent/tests
 # EB_AGENT_ACCESS_TOKEN=<token> pytest integration_tests/agents/test_functional_agent_stream.py -s

+
 @pytest.fixture(scope="module")
 def llm():
     return ERNIEBot(model="ernie-3.5")
@@ -36,7 +37,7 @@ async def test_function_agent_run_one_hit(llm, tool, memory):

     run_logs = []
     async for step, msgs in agent.run_stream(prompt):
-        run_logs.append((step, msgs))
+        run_logs.append((step, msgs))

     assert len(agent.memory.get_messages()) == 2
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
@@ -48,11 +49,12 @@ async def test_function_agent_run_one_hit(llm, tool, memory):
     assert run_logs[0][1][1].name == run_logs[0][1][0].function_call["name"]
     assert json.loads(run_logs[0][1][1].content) == {"formula_result": 5}
     assert isinstance(agent.memory.get_messages()[1], AIMessage)

     steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
     assert len(steps) == 1
     assert steps[0].info["tool_name"] == tool.tool_name

+
 @pytest.mark.asyncio
 async def test_function_agent_run_no_hit(llm, tool, memory):
     agent = FunctionAgent(llm=llm, tools=[tool], memory=memory)
@@ -66,7 +68,9 @@ async def test_function_agent_run_no_hit(llm, tool, memory):
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
     assert agent.memory.get_messages()[0].content == prompt
     assert isinstance(run_logs[0][1][0], AIMessage)
-    end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"])
+    end_step_msg = "".join(
+        [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
+    )
     assert end_step_msg == agent.memory.get_messages()[1].content

     steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
@@ -86,12 +90,10 @@ async def test_function_agent_run_no_tool(llm, memory, prompt):
     assert isinstance(agent.memory.get_messages()[0], HumanMessage)
     assert agent.memory.get_messages()[0].content == prompt
     assert isinstance(run_logs[0][1][0], AIMessage)
-    end_step_msg = "".join([msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"])
+    end_step_msg = "".join(
+        [msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
+    )
     assert end_step_msg == agent.memory.get_messages()[1].content

     steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
     assert len(steps) == 0
-
-
-
-
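The Makefile target named in the commit title is not shown in this diff; conventionally a format-check target runs the formatters in check-only mode and fails on violations without modifying files. A hedged approximation in Python (the actual ERNIE-SDK Makefile may invoke different tools or flags):

import subprocess

# Assumed equivalent of `make format-check`: run black and isort in
# check-only mode; a non-zero exit code from either tool fails the check.
for cmd in (["black", "--check", "."], ["isort", "--check-only", "."]):
    subprocess.run(cmd, check=True)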