diff --git a/erniebot-agent/examples/cv_agent/CV_agent.py b/erniebot-agent/examples/cv_agent/CV_agent.py index 6a397da70..231619442 100644 --- a/erniebot-agent/examples/cv_agent/CV_agent.py +++ b/erniebot-agent/examples/cv_agent/CV_agent.py @@ -2,7 +2,7 @@ from erniebot_agent.agents.functional_agent import FunctionalAgent from erniebot_agent.chat_models.erniebot import ERNIEBot -from erniebot_agent.file_io import ( +from erniebot_agent.file import ( configure_global_file_manager, get_global_file_manager, ) @@ -25,7 +25,7 @@ def __init__(self): async def run_agent(): - file_manager = get_global_file_manager() + file_manager = await get_global_file_manager() seg_file = await file_manager.create_file_from_path(file_path="cityscapes_demo.png", file_type="local") clas_file = await file_manager.create_file_from_path(file_path="class_img.jpg", file_type="local") ocr_file = await file_manager.create_file_from_path(file_path="ch.png", file_type="local") diff --git a/erniebot-agent/examples/plugins/multiple_plugins.py b/erniebot-agent/examples/plugins/multiple_plugins.py index 29d628813..a64dfbd43 100644 --- a/erniebot-agent/examples/plugins/multiple_plugins.py +++ b/erniebot-agent/examples/plugins/multiple_plugins.py @@ -6,9 +6,9 @@ from erniebot_agent.agents.callback.default import get_no_ellipsis_callback from erniebot_agent.agents.functional_agent import FunctionalAgent from erniebot_agent.chat_models.erniebot import ERNIEBot -from erniebot_agent.file_io import get_global_file_manager +from erniebot_agent.file import get_global_file_manager from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory -from erniebot_agent.messages import AIMessage, HumanMessage, Message +from erniebot_agent.memory import AIMessage, HumanMessage, Message from erniebot_agent.tools.base import Tool from erniebot_agent.tools.calculator_tool import CalculatorTool from erniebot_agent.tools.schema import ToolParameterView @@ -32,7 +32,7 @@ async def __call__(self, input_file_id: str, repeat_times: int) -> Dict[str, Any if "" in input_file_id: input_file_id = input_file_id.split("")[0] - file_manager = get_global_file_manager() + file_manager = await get_global_file_manager() input_file = file_manager.look_up_file_by_id(input_file_id) if input_file is None: raise RuntimeError("File not found") @@ -109,20 +109,20 @@ def examples(self) -> List[Message]: # TODO(shiyutang): replace this when model is online llm = ERNIEBot(model="ernie-3.5", api_type="custom") memory = SlidingWindowMemory(max_round=1) -file_manager = get_global_file_manager() # plugins = ["ChatFile", "eChart"] plugins: List[str] = [] agent = FunctionalAgent( llm=llm, tools=[TextRepeaterTool(), TextRepeaterNoFileTool(), CalculatorTool()], memory=memory, - file_manager=file_manager, callbacks=get_no_ellipsis_callback(), plugins=plugins, ) async def run_agent(): + file_manager = await get_global_file_manager() + docx_file = await file_manager.create_file_from_path( file_path="浅谈牛奶的营养与消费趋势.docx", file_type="remote", diff --git a/erniebot-agent/examples/rpg_game_agent.py b/erniebot-agent/examples/rpg_game_agent.py index 2c9aecfba..4a03edc70 100644 --- a/erniebot-agent/examples/rpg_game_agent.py +++ b/erniebot-agent/examples/rpg_game_agent.py @@ -24,9 +24,7 @@ from erniebot_agent.agents.base import Agent from erniebot_agent.agents.schema import AgentFile, AgentResponse from erniebot_agent.chat_models.erniebot import ERNIEBot -from erniebot_agent.file_io import get_global_file_manager -from erniebot_agent.file_io.base import File -from erniebot_agent.file_io.file_manager import FileManager +from erniebot_agent.file.base import File from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory from erniebot_agent.messages import AIMessage, HumanMessage, SystemMessage from erniebot_agent.tools.base import BaseTool @@ -87,7 +85,6 @@ def __init__( tools=tools, system_message=system_message, ) - self.file_manager: FileManager = get_global_file_manager() async def handle_tool(self, tool_name: str, tool_args: str) -> str: tool_response = await self._async_run_tool( diff --git a/erniebot-agent/src/erniebot_agent/agents/base.py b/erniebot-agent/src/erniebot_agent/agents/base.py index 10646cd7e..623480b5d 100644 --- a/erniebot-agent/src/erniebot_agent/agents/base.py +++ b/erniebot-agent/src/erniebot_agent/agents/base.py @@ -76,18 +76,16 @@ def __init__( self._callback_manager = callbacks else: self._callback_manager = CallbackManager(callbacks) - if file_manager is None: - file_manager = get_global_file_manager() - self.plugins = plugins self._file_manager = file_manager + self._plugins = plugins self._init_file_repr() def _init_file_repr(self): self.file_needs_url = False - if self.plugins: + if self._plugins: PLUGIN_WO_FILE = ["eChart"] - for plugin in self.plugins: + for plugin in self._plugins: if plugin not in PLUGIN_WO_FILE: self.file_needs_url = True @@ -175,12 +173,8 @@ async def _sniff_and_extract_files_from_args( for val in args.values(): if isinstance(val, str): if protocol.is_file_id(val): - if self._file_manager is None: - logger.warning( - f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it." - ) - continue - file = self._file_manager.look_up_file_by_id(val) + file_manager = await self._get_file_manager() + file = file_manager.look_up_file_by_id(val) if file is None: raise RuntimeError(f"Unregistered file with ID {repr(val)} is used by {repr(tool)}.") agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name)) @@ -190,3 +184,10 @@ async def _sniff_and_extract_files_from_args( for item in val: agent_files.extend(await self._sniff_and_extract_files_from_args(item, tool, file_type)) return agent_files + + async def _get_file_manager(self) -> FileManager: + if self._file_manager is None: + file_manager = await get_global_file_manager() + else: + file_manager = self._file_manager + return file_manager diff --git a/erniebot-agent/src/erniebot_agent/file/global_file_manager.py b/erniebot-agent/src/erniebot_agent/file/global_file_manager.py index eb20814e3..4b17c834f 100644 --- a/erniebot-agent/src/erniebot_agent/file/global_file_manager.py +++ b/erniebot-agent/src/erniebot_agent/file/global_file_manager.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from typing import Any, Optional import asyncio_atexit # type: ignore @@ -21,12 +22,14 @@ from erniebot_agent.utils import config_from_environ as C _global_file_manager: Optional[FileManager] = None +_lock = asyncio.Lock() async def get_global_file_manager() -> FileManager: global _global_file_manager - if _global_file_manager is None: - _global_file_manager = await _create_default_file_manager(access_token=None, save_dir=None) + async with _lock: + if _global_file_manager is None: + _global_file_manager = await _create_default_file_manager(access_token=None, save_dir=None) return _global_file_manager @@ -34,11 +37,12 @@ async def configure_global_file_manager( access_token: Optional[str] = None, save_dir: Optional[str] = None, **opts: Any ) -> None: global _global_file_manager - if _global_file_manager is not None: - raise RuntimeError( - "The global file manager can only be configured once before calling `get_global_file_manager`." - ) - _global_file_manager = await _create_default_file_manager(access_token=access_token, save_dir=save_dir, **opts) + async with _lock: + if _global_file_manager is not None: + raise RuntimeError( + "The global file manager can only be configured once before calling `get_global_file_manager`." + ) + _global_file_manager = await _create_default_file_manager(access_token=access_token, save_dir=save_dir, **opts) async def _create_default_file_manager( diff --git a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py index ef3ab1a8d..06b387ce6 100644 --- a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py @@ -4,6 +4,7 @@ import json from copy import deepcopy from typing import Any, Dict, List, Optional, Type +from erniebot_agent.file import get_global_file_manager import requests @@ -45,7 +46,7 @@ def __init__( server_url: str, headers: dict, version: str, - file_manager: FileManager, + file_manager: Optional[FileManager], examples: Optional[List[Message]] = None, tool_name_prefix: Optional[str] = None, ) -> None: @@ -86,11 +87,13 @@ async def fileid_to_byte(file_id, file_manager): async def convert_to_file_data(file_data: str, format: str): value = file_data.replace("", "").replace("", "") - byte_value = await fileid_to_byte(value, self.file_manager) + byte_value = await fileid_to_byte(value, file_manager) if format == "byte": byte_value = base64.b64encode(byte_value).decode() return byte_value + file_manager = await self._get_file_manager() + # 1. replace fileid with byte string parameter_file_info = get_file_info_from_param_view(self.tool_view.parameters) for key in tool_arguments.keys(): @@ -203,19 +206,21 @@ async def send_request(self, tool_arguments: Dict[str, Any]) -> dict: if len(returns_file_infos) == 0 and is_json_response(response): return response.json() + file_manager = await self._get_file_manager() + file_metadata = {"tool_name": self.tool_name} if is_json_response(response) and len(returns_file_infos) > 0: response_json = response.json() file_info = await parse_file_from_json_response( response_json, - file_manager=self.file_manager, + file_manager=file_manager, param_view=self.tool_view.returns, # type: ignore tool_name=self.tool_name, ) response_json.update(file_info) return response_json file = await parse_file_from_response( - response, self.file_manager, file_infos=returns_file_infos, file_metadata=file_metadata + response, file_manager, file_infos=returns_file_infos, file_metadata=file_metadata ) if file is not None: @@ -293,6 +298,13 @@ def __adhoc_post_process__(self, tool_response: dict) -> dict: result.pop(key) return tool_response + async def _get_file_manager(self) -> FileManager: + if self.file_manager is None: + file_manager = await get_global_file_manager() + else: + file_manager = self.file_manager + return file_manager + class RemoteToolRegistor: def __init__(self) -> None: diff --git a/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py b/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py index 62cdbf8b3..94f7e60c1 100644 --- a/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py +++ b/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py @@ -12,7 +12,6 @@ from openapi_spec_validator.readers import read_from_filename from yaml import safe_dump -from erniebot_agent.file import get_global_file_manager from erniebot_agent.file.file_manager import FileManager from erniebot_agent.memory.messages import ( AIMessage, @@ -201,9 +200,6 @@ def from_openapi_dict( ) ) - if file_manager is None: - file_manager = get_global_file_manager() - return RemoteToolkit( openapi=openapi_dict["openapi"], info=info, diff --git a/erniebot-agent/tests/integration_tests/apihub/test_pp_shituv2.py b/erniebot-agent/tests/integration_tests/apihub/test_pp_shituv2.py index bdf8262dd..d57373e86 100644 --- a/erniebot-agent/tests/integration_tests/apihub/test_pp_shituv2.py +++ b/erniebot-agent/tests/integration_tests/apihub/test_pp_shituv2.py @@ -1,8 +1,8 @@ from __future__ import annotations +from erniebot_agent.file.file_manager import FileManager import pytest -from erniebot_agent.file_io import get_global_file_manager from erniebot_agent.tools import RemoteToolkit from .base import RemoteToolTesting @@ -17,13 +17,13 @@ async def test_pp_shituv2(self): agent = self.get_agent(toolkit) - file_manager = get_global_file_manager() - file_path = self.download_file( - "https://paddlenlp.bj.bcebos.com/ebagent/ci/fixtures/remote-tools/pp_shituv2_input_img.png" - ) - file = await file_manager.create_file_from_path(file_path) + async with FileManager() as file_manager: + file_path = self.download_file( + "https://paddlenlp.bj.bcebos.com/ebagent/ci/fixtures/remote-tools/pp_shituv2_input_img.png" + ) + file = await file_manager.create_file_from_path(file_path) - result = await agent.async_run("对这张图片进行通用识别,包含的文件为:", files=[file]) - self.assertEqual(len(result.files), 2) - self.assertEqual(len(result.actions), 1) - self.assertIn("file-", result.text) + result = await agent.async_run("对这张图片进行通用识别,包含的文件为:", files=[file]) + self.assertEqual(len(result.files), 2) + self.assertEqual(len(result.actions), 1) + self.assertIn("file-", result.text)