Skip to content

Commit

Permalink
根据review意见进行修改
Browse files Browse the repository at this point in the history
  • Loading branch information
xiabo0816 committed Apr 30, 2024
1 parent 6ad4797 commit 2d3c851
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 32 deletions.
30 changes: 24 additions & 6 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
async def run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent asynchronously.
"""Run the agent asynchronously, returning an async iterator of responses.
Args:
prompt: A natural language text describing the task that the agent
should perform.
files: A list of files that the agent can use to perform the task.
Returns:
Response AgentStep from the agent.
Iterator of responses from the agent.
"""
if files:
await self._ensure_managed_files(files)
Expand All @@ -164,7 +164,7 @@ async def run_stream(
await self._callback_manager.on_run_end(
agent=self,
response=AgentResponse(
text="Agent run stopped early.",
text="Agent run stopped.",
chat_history=self.memory.get_messages(),
steps=[step],
status="STOPPED",
Expand All @@ -177,7 +177,7 @@ async def run_llm(
messages: List[Message],
**llm_opts: Any,
) -> LLMResponse:
"""Run the LLM asynchronously.
"""Run the LLM asynchronously, returning final response.
Args:
messages: The input messages.
Expand All @@ -202,14 +202,14 @@ async def run_llm_stream(
messages: List[Message],
**llm_opts: Any,
) -> AsyncIterator[LLMResponse]:
"""Run the LLM asynchronously.
"""Run the LLM asynchronously, returning an async iterator of responses
Args:
messages: The input messages.
llm_opts: Options to pass to the LLM.
Returns:
Response from the LLM.
Iterator of responses from the LLM.
"""
llm_resp = None
await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages)
Expand Down Expand Up @@ -305,6 +305,15 @@ async def _run_stream(
yield only_for_mypy_type_check

async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
"""Run the LLM with the given messages and options.
Args:
messages: The input messages.
opts: Options to pass to the LLM.
Returns:
Response from the LLM.
"""
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")
Expand All @@ -325,6 +334,15 @@ async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
return LLMResponse(message=llm_ret)

async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIterator[LLMResponse]:
"""Run the LLM, yielding an async iterator of responses.
Args:
messages: The input messages.
opts: Options to pass to the LLM.
Returns:
Async iterator of responses from the LLM.
"""
for reserved_opt in ("stream", "system", "plugins"):
if reserved_opt in opts:
raise TypeError(f"`{reserved_opt}` should not be set.")
Expand Down
20 changes: 11 additions & 9 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool)
curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
Expand Down Expand Up @@ -181,21 +181,21 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
response = self._create_stopped_response(chat_history, steps_taken)
return response

async def _first_tool_step(
async def _call_first_tools(
self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
) -> Tuple[AgentStep, List[Message]]:
input_messages = self.memory.get_messages() + chat_history
if selected_tool is None:
llm_resp = await self.run_llm(messages=input_messages)
return await self._schema_format(llm_resp, chat_history)
return await self._process_step(llm_resp, chat_history)

tool_choice = {"type": "function", "function": {"name": selected_tool.tool_name}}
llm_resp = await self.run_llm(
messages=input_messages,
functions=[selected_tool.function_call_schema()], # only regist one tool
tool_choice=tool_choice,
)
return await self._schema_format(llm_resp, chat_history)
return await self._process_step(llm_resp, chat_history)

async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Message]]:
"""Run a step of the agent.
Expand All @@ -206,7 +206,7 @@ async def _step(self, chat_history: List[Message]) -> Tuple[AgentStep, List[Mess
"""
input_messages = self.memory.get_messages() + chat_history
llm_resp = await self.run_llm(messages=input_messages)
return await self._schema_format(llm_resp, chat_history)
return await self._process_step(llm_resp, chat_history)

async def _step_stream(
self, chat_history: List[Message]
Expand All @@ -219,7 +219,7 @@ async def _step_stream(
"""
input_messages = self.memory.get_messages() + chat_history
async for llm_resp in self.run_llm_stream(messages=input_messages):
yield await self._schema_format(llm_resp, chat_history)
yield await self._process_step(llm_resp, chat_history)

async def _run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
Expand All @@ -244,7 +244,7 @@ async def _run_stream(
chat_history.append(run_input)

for tool in self._first_tools:
curr_step, new_messages = await self._first_tool_step(chat_history, selected_tool=tool)
curr_step, new_messages = await self._call_first_tools(chat_history, selected_tool=tool)
if not isinstance(curr_step, EndStep):
chat_history.extend(new_messages)
num_steps_taken += 1
Expand Down Expand Up @@ -285,8 +285,8 @@ async def _run_stream(
end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
self.memory.add_message(end_step_msg)

async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
"""Convert the LLM response to the agent response schema.
async def _process_step(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
"""Process and execute a step of the agent from LLM response.
Args:
llm_resp: The LLM response to convert.
chat_history: The chat history to provide to the agent.
Expand All @@ -296,6 +296,7 @@ async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[
new_messages: List[Message] = []
output_message = llm_resp.message # AIMessage
new_messages.append(output_message)
# handle function call
if output_message.function_call is not None:
tool_name = output_message.function_call["name"]
tool_args = output_message.function_call["arguments"]
Expand All @@ -310,6 +311,7 @@ async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[
),
new_messages,
)
# handle plugin info with input/output files
elif output_message.plugin_info is not None:
file_manager = self.get_file_manager()
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from erniebot_agent.agents import FunctionAgent
from erniebot_agent.agents.schema import EndStep
from erniebot_agent.chat_models import ERNIEBot
from erniebot_agent.memory import WholeMemory
from erniebot_agent.memory.messages import (
Expand All @@ -16,9 +17,6 @@
ONE_HIT_PROMPT = "1+4等于几?"
NO_HIT_PROMPT = "深圳今天天气怎么样?"

# cd ERNIE-SDK/erniebot-agent/tests
# EB_AGENT_ACCESS_TOKEN=<token> pytest integration_tests/agents/test_functional_agent_stream.py -s


@pytest.fixture(scope="module")
def llm():
Expand Down Expand Up @@ -55,9 +53,9 @@ async def test_function_agent_run_one_hit(llm, tool, memory):
assert json.loads(run_logs[0][1][1].content) == {"formula_result": 5}
assert isinstance(agent.memory.get_messages()[1], AIMessage)

steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
assert len(steps) == 1
assert steps[0].info["tool_name"] == tool.tool_name
tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)]
assert len(tool_steps) == 1
assert tool_steps[0].info["tool_name"] == tool.tool_name


@pytest.mark.asyncio
Expand All @@ -73,14 +71,12 @@ async def test_function_agent_run_no_hit(llm, tool, memory):
assert isinstance(agent.memory.get_messages()[0], HumanMessage)
assert agent.memory.get_messages()[0].content == prompt
assert isinstance(run_logs[0][1][0], AIMessage)
end_step_msg = "".join(
[msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
)
end_step_msg = "".join([msg[0].content for step, msg in run_logs if isinstance(step, EndStep)])
assert end_step_msg == agent.memory.get_messages()[1].content

steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
assert len(steps) == 0

tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)]
assert len(tool_steps) == 0

@pytest.mark.asyncio
@pytest.mark.parametrize("prompt", [ONE_HIT_PROMPT, NO_HIT_PROMPT])
Expand All @@ -95,10 +91,8 @@ async def test_function_agent_run_no_tool(llm, memory, prompt):
assert isinstance(agent.memory.get_messages()[0], HumanMessage)
assert agent.memory.get_messages()[0].content == prompt
assert isinstance(run_logs[0][1][0], AIMessage)
end_step_msg = "".join(
[msg[0].content for step, msg in run_logs if step.__class__.__name__ == "EndStep"]
)
end_step_msg = "".join([msg[0].content for step, msg in run_logs if isinstance(step, EndStep)])
assert end_step_msg == agent.memory.get_messages()[1].content

steps = [step for step, msgs in run_logs if step.__class__.__name__ != "EndStep"]
assert len(steps) == 0
tool_steps = [step for step, msgs in run_logs if not isinstance(step, EndStep)]
assert len(tool_steps) == 0

0 comments on commit 2d3c851

Please sign in to comment.