Skip to content

Commit

Permalink
使用python -m mypy src检查和优化代码
Browse files Browse the repository at this point in the history
  • Loading branch information
xiabo0816 committed Apr 28, 2024
1 parent cc8b31c commit 72f97f6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
30 changes: 23 additions & 7 deletions erniebot-agent/src/erniebot_agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from erniebot_agent.agents.callback.handlers.base import CallbackHandler
from erniebot_agent.agents.mixins import GradioMixin
from erniebot_agent.agents.schema import (
DEFAULT_FINISH_STEP,
AgentResponse,
AgentStep,
LLMResponse,
Expand Down Expand Up @@ -140,7 +141,7 @@ async def run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Agen
@final
async def run_stream(
self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[AgentStep]:
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
"""Run the agent asynchronously.
Args:
Expand All @@ -154,14 +155,21 @@ async def run_stream(
await self._ensure_managed_files(files)
await self._callback_manager.on_run_start(agent=self, prompt=prompt)
try:
async for step in self._run_stream(prompt, files):
yield step
async for step, msg in self._run_stream(prompt, files):
yield (step, msg)
except BaseException as e:
await self._callback_manager.on_run_error(agent=self, error=e)
raise e
else:
await self._callback_manager.on_run_end(agent=self, response=None)
return
await self._callback_manager.on_run_end(
agent=self,
response=AgentResponse(
text="Agent run stopped early.",
chat_history=self.memory.get_messages(),
steps=[step],
status="STOPPED",
),
)

@final
async def run_llm(
Expand Down Expand Up @@ -284,8 +292,16 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
@abc.abstractmethod
async def _run_stream(
    self, prompt: str, files: Optional[Sequence[File]] = None
) -> AsyncIterator[Tuple[AgentStep, List[Message]]]:
    """Run the agent, yielding intermediate steps as they are produced.

    Abstract asynchronous generator that subclasses must implement. It
    should yield ``(AgentStep, List[Message])`` tuples derived from the
    given prompt and the optionally accompanying files.

    Args:
        prompt: The user request for the agent to act on.
        files: Files accompanying the prompt, if any.

    Yields:
        Tuples of the step just taken and the new messages produced by
        that step.
    """
    # This unreachable block exists solely so static type checkers
    # (e.g. mypy) infer this method as an async *generator* rather than
    # a plain coroutine; it is never executed at runtime.
    if False:
        only_for_mypy_type_check: Tuple[AgentStep, List[Message]] = (DEFAULT_FINISH_STEP, [])
        yield only_for_mypy_type_check

async def _run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse:
for reserved_opt in ("stream", "system", "plugins"):
Expand Down
9 changes: 5 additions & 4 deletions erniebot-agent/src/erniebot_agent/agents/function_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
return response

async def _first_tool_step(
self, chat_history: List[Message], selected_tool: BaseTool = None
self, chat_history: List[Message], selected_tool: Optional[BaseTool] = None
) -> Tuple[AgentStep, List[Message]]:
input_messages = self.memory.get_messages() + chat_history
if selected_tool is None:
Expand Down Expand Up @@ -254,7 +254,6 @@ async def _run_stream(
_logger.warning(f"Selected tool [{tool.tool_name}] not work")

is_finished = False
curr_step = None
new_messages = []
end_step_msgs = []
while is_finished is False:
Expand All @@ -274,17 +273,19 @@ async def _run_stream(
curr_step = DEFAULT_FINISH_STEP
yield curr_step, new_messages

if isinstance(curr_step, EndStep):
elif isinstance(curr_step, EndStep):
is_finished = True
end_step_msgs.extend(new_messages)
yield curr_step, new_messages
else:
raise RuntimeError("Invalid step type")
chat_history.extend(new_messages)

self.memory.add_message(run_input)
end_step_msg = AIMessage(content="".join([item.content for item in end_step_msgs]))
self.memory.add_message(end_step_msg)

async def _schema_format(self, llm_resp, chat_history):
async def _schema_format(self, llm_resp, chat_history) -> Tuple[AgentStep, List[Message]]:
"""Convert the LLM response to the agent response schema.
Args:
llm_resp: The LLM response to convert.
Expand Down

0 comments on commit 72f97f6

Please sign in to comment.