From 3b2792964bfa472d9f87c409642311885a87cef7 Mon Sep 17 00:00:00 2001
From: xiabo
Date: Wed, 3 Jul 2024 20:30:08 +0800
Subject: [PATCH] =?UTF-8?q?=E5=8F=AA=E6=9C=89=E5=85=B3=E9=97=AD=E5=A4=9A?=
 =?UTF-8?q?=E6=AD=A5=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=E6=97=B6=E6=89=8D?=
 =?UTF-8?q?=E7=94=A8=E6=B5=81=E5=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../src/erniebot_agent/agents/agent.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py
index d04a9c96..ed02f714 100644
--- a/erniebot-agent/src/erniebot_agent/agents/agent.py
+++ b/erniebot-agent/src/erniebot_agent/agents/agent.py
@@ -361,9 +361,16 @@ async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIt
         )
         opts["system"] = self.system.content if self.system is not None else None
         opts["plugins"] = self._plugins
-        llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts)
-        async for msg in llm_ret:
-            yield LLMResponse(message=msg)
+        # A streaming response cannot carry multiple tool calls at once, so a
+        # streaming request is issued only when multi-step tool calling is
+        # disabled; otherwise fall back to a single non-streaming request and
+        # yield its one message so callers still see an async stream.
+        if self.llm.extra_data.get("multi_step_tool_call_close", True):
+            llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts)
+            async for msg in llm_ret:
+                yield LLMResponse(message=msg)
+        else:
+            msg = await self.llm.chat(messages, stream=False, functions=functions, **opts)
+            yield LLMResponse(message=msg)
 
     async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse:
         parsed_tool_args = self._parse_tool_args(tool_args)