diff --git a/erniebot-agent/Makefile b/erniebot-agent/Makefile index 0b198b07a..1f527f474 100644 --- a/erniebot-agent/Makefile +++ b/erniebot-agent/Makefile @@ -1,6 +1,9 @@ -.DEFAULT_GOAL = format lint type_check +.DEFAULT_GOAL = dev files_to_format_and_lint = src examples tests +.PHONY: dev +dev: format lint type_check + .PHONY: format format: python -m black $(files_to_format_and_lint) diff --git a/erniebot-agent/examples/cv_agent/CV_agent.py b/erniebot-agent/examples/cv_agent/CV_agent.py index 3eae060f0..2cd56892b 100644 --- a/erniebot-agent/examples/cv_agent/CV_agent.py +++ b/erniebot-agent/examples/cv_agent/CV_agent.py @@ -2,7 +2,7 @@ from erniebot_agent.agents.functional_agent import FunctionalAgent from erniebot_agent.chat_models.erniebot import ERNIEBot -from erniebot_agent.file_io import get_file_manager +from erniebot_agent.file_io import get_global_file_manager from erniebot_agent.memory.whole_memory import WholeMemory from erniebot_agent.tools import RemoteToolkit @@ -17,7 +17,7 @@ def __init__(self): llm = ERNIEBot(model="ernie-3.5", api_type="aistudio", access_token="") toolkit = CVToolkit() memory = WholeMemory() -file_manager = get_file_manager() +file_manager = get_global_file_manager(access_token=None) agent = FunctionalAgent(llm=llm, tools=toolkit.tools, memory=memory, file_manager=file_manager) diff --git a/erniebot-agent/examples/plugins/multiple_plugins.py b/erniebot-agent/examples/plugins/multiple_plugins.py index d858410e8..b79f1533d 100644 --- a/erniebot-agent/examples/plugins/multiple_plugins.py +++ b/erniebot-agent/examples/plugins/multiple_plugins.py @@ -6,7 +6,7 @@ from erniebot_agent.agents.callback.default import get_no_ellipsis_callback from erniebot_agent.agents.functional_agent import FunctionalAgent from erniebot_agent.chat_models.erniebot import ERNIEBot -from erniebot_agent.file_io import get_file_manager +from erniebot_agent.file_io import get_global_file_manager from erniebot_agent.memory.sliding_window_memory import 
SlidingWindowMemory from erniebot_agent.messages import AIMessage, HumanMessage, Message from erniebot_agent.tools.base import Tool @@ -32,7 +32,7 @@ async def __call__(self, input_file_id: str, repeat_times: int) -> Dict[str, Any if "" in input_file_id: input_file_id = input_file_id.split("")[0] - file_manager = get_file_manager() # Access_token needs to be set here. + file_manager = get_global_file_manager(access_token=None) # Access_token needs to be set here. input_file = file_manager.look_up_file_by_id(input_file_id) if input_file is None: raise RuntimeError("File not found") @@ -109,7 +109,7 @@ def examples(self) -> List[Message]: # TODO(shiyutang): replace this when model is online llm = ERNIEBot(model="ernie-3.5", api_type="custom") memory = SlidingWindowMemory(max_round=1) -file_manager = get_file_manager(access_token="") # Access_token needs to be set here. +file_manager = get_global_file_manager(access_token="") # Access_token needs to be set here. # plugins = ["ChatFile", "eChart"] plugins: List[str] = [] agent = FunctionalAgent( diff --git a/erniebot-agent/examples/rpg_game_agent.py b/erniebot-agent/examples/rpg_game_agent.py index 0ef00b23e..b338d7407 100644 --- a/erniebot-agent/examples/rpg_game_agent.py +++ b/erniebot-agent/examples/rpg_game_agent.py @@ -24,7 +24,7 @@ from erniebot_agent.agents.base import Agent from erniebot_agent.agents.schema import AgentFile, AgentResponse from erniebot_agent.chat_models.erniebot import ERNIEBot -from erniebot_agent.file_io import get_file_manager +from erniebot_agent.file_io import get_global_file_manager from erniebot_agent.file_io.base import File from erniebot_agent.file_io.file_manager import FileManager from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory @@ -87,7 +87,7 @@ def __init__( tools=tools, system_message=system_message, ) - self.file_manager: FileManager = get_file_manager() + self.file_manager: FileManager = get_global_file_manager(access_token) async def handle_tool(self, 
tool_name: str, tool_args: str) -> str: tool_response = await self._async_run_tool( diff --git a/erniebot-agent/src/erniebot_agent/agents/base.py b/erniebot-agent/src/erniebot_agent/agents/base.py index 630767a83..84b9bd215 100644 --- a/erniebot-agent/src/erniebot_agent/agents/base.py +++ b/erniebot-agent/src/erniebot_agent/agents/base.py @@ -14,7 +14,8 @@ import abc import json -from typing import Any, Dict, List, Literal, Optional, Union +import logging +from typing import Any, Dict, List, Literal, Optional, Union, final from erniebot_agent import file_io from erniebot_agent.agents.callback.callback_manager import CallbackManager @@ -27,15 +28,16 @@ ToolResponse, ) from erniebot_agent.chat_models.base import ChatModel +from erniebot_agent.file_io import protocol from erniebot_agent.file_io.base import File from erniebot_agent.file_io.file_manager import FileManager -from erniebot_agent.file_io.protocol import is_local_file_id, is_remote_file_id from erniebot_agent.memory.base import Memory from erniebot_agent.messages import Message, SystemMessage from erniebot_agent.tools.base import BaseTool from erniebot_agent.tools.tool_manager import ToolManager -from erniebot_agent.utils.gradio_mixin import GradioMixin -from erniebot_agent.utils.logging import logger +from erniebot_agent.utils.mixins import GradioMixin + +logger = logging.getLogger(__name__) class BaseAgent(metaclass=abc.ABCMeta): @@ -76,7 +78,7 @@ def __init__( else: self._callback_manager = CallbackManager(callbacks) if file_manager is None: - file_manager = file_io.get_file_manager() + file_manager = file_io.get_global_file_manager(access_token=None) self.plugins = plugins self._file_manager = file_manager self._init_file_repr() @@ -94,6 +96,7 @@ def _init_file_repr(self): def tools(self) -> List[BaseTool]: return self._tool_manager.get_tools() + @final async def async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: await self._callback_manager.on_run_start(agent=self, 
prompt=prompt) agent_resp = await self._async_run(prompt, files) @@ -113,6 +116,7 @@ def reset_memory(self) -> None: async def _async_run(self, prompt: str, files: Optional[List[File]] = None) -> AgentResponse: raise NotImplementedError + @final async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse: tool = self._tool_manager.get_tool(tool_name) await self._callback_manager.on_tool_start(agent=self, tool=tool, input_args=tool_args) @@ -124,6 +128,7 @@ async def _async_run_tool(self, tool_name: str, tool_args: str) -> ToolResponse: await self._callback_manager.on_tool_end(agent=self, tool=tool, response=tool_resp) return tool_resp + @final async def _async_run_llm(self, messages: List[Message], **opts: Any) -> LLMResponse: await self._callback_manager.on_llm_start(agent=self, llm=self.llm, messages=messages) try: @@ -170,16 +175,7 @@ async def _sniff_and_extract_files_from_args( agent_files: List[AgentFile] = [] for val in args.values(): if isinstance(val, str): - if is_local_file_id(val): - if self._file_manager is None: - logger.warning( - f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it." - ) - continue - file = self._file_manager.look_up_file_by_id(val) - if file is None: - raise RuntimeError(f"Unregistered ID {repr(val)} is used by {repr(tool)}.") - elif is_remote_file_id(val): + if protocol.is_file_id(val): if self._file_manager is None: logger.warning( f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it." 
@@ -187,10 +183,8 @@ async def _sniff_and_extract_files_from_args( continue file = self._file_manager.look_up_file_by_id(val) if file is None: - file = await self._file_manager.retrieve_remote_file_by_id(val) - else: - continue - agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name)) + raise RuntimeError(f"Unregistered file with ID {repr(val)} is used by {repr(tool)}.") + agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name)) elif isinstance(val, dict): agent_files.extend(await self._sniff_and_extract_files_from_args(val, tool, file_type)) elif isinstance(val, list) and len(val) > 0 and isinstance(val[0], dict): diff --git a/erniebot-agent/src/erniebot_agent/agents/callback/handlers/logging_handler.py b/erniebot-agent/src/erniebot_agent/agents/callback/handlers/logging_handler.py index e712f99d0..9f980cabd 100644 --- a/erniebot-agent/src/erniebot_agent/agents/callback/handlers/logging_handler.py +++ b/erniebot-agent/src/erniebot_agent/agents/callback/handlers/logging_handler.py @@ -23,9 +23,10 @@ from erniebot_agent.messages import Message from erniebot_agent.tools.base import BaseTool from erniebot_agent.utils.json import to_pretty_json -from erniebot_agent.utils.logging import logger as default_logger from erniebot_agent.utils.output_style import ColoredContent +default_logger = logging.getLogger(__name__) + if TYPE_CHECKING: from erniebot_agent.agents.base import Agent diff --git a/erniebot-agent/src/erniebot_agent/agents/schema.py b/erniebot-agent/src/erniebot_agent/agents/schema.py index eaf6d73cd..75390bd26 100644 --- a/erniebot-agent/src/erniebot_agent/agents/schema.py +++ b/erniebot-agent/src/erniebot_agent/agents/schema.py @@ -18,8 +18,8 @@ from dataclasses import dataclass from typing import Dict, List, Literal, Optional, Tuple, Union +from erniebot_agent.file_io import protocol from erniebot_agent.file_io.base import File -from erniebot_agent.file_io.protocol import extract_file_ids from 
erniebot_agent.messages import AIMessage, Message @@ -80,6 +80,8 @@ def get_output_files(self) -> List[File]: return [agent_file.file for agent_file in self.files if agent_file.type == "output"] def get_tool_input_output_files(self, tool_name: str) -> Tuple[List[File], List[File]]: + # XXX: If a tool is used multiple times, all related files will be + # returned in flattened lists. input_files: List[File] = [] output_files: List[File] = [] for agent_file in self.files: @@ -94,7 +96,7 @@ def get_tool_input_output_files(self, tool_name: str) -> Tuple[List[File], List[ def output_dict(self) -> Dict[str, List]: # 1. split the text into parts and add file id to each part - file_ids = extract_file_ids(self.text) + file_ids = protocol.extract_file_ids(self.text) places = [] for file_id in file_ids: diff --git a/erniebot-agent/src/erniebot_agent/file_io/__init__.py b/erniebot-agent/src/erniebot_agent/file_io/__init__.py index 89807bda3..75363924b 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/__init__.py +++ b/erniebot-agent/src/erniebot_agent/file_io/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from erniebot_agent.file_io.factory import get_file_manager +from erniebot_agent.file_io.factory import get_global_file_manager diff --git a/erniebot-agent/src/erniebot_agent/file_io/base.py b/erniebot-agent/src/erniebot_agent/file_io/base.py index f1dcc1a8c..6b1e466b4 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/base.py +++ b/erniebot-agent/src/erniebot_agent/file_io/base.py @@ -23,7 +23,7 @@ def __init__( id: str, filename: str, byte_size: int, - created_at: int, + created_at: str, purpose: str, metadata: Dict[str, Any], ) -> None: @@ -40,7 +40,7 @@ def __eq__(self, other: object) -> bool: if isinstance(other, File): return self.id == other.id else: - return False + return NotImplemented def __repr__(self) -> str: attrs_str = self._get_attrs_str() diff --git a/erniebot-agent/src/erniebot_agent/file_io/caching.py b/erniebot-agent/src/erniebot_agent/file_io/caching.py new file mode 100644 index 000000000..b84625dc8 --- /dev/null +++ b/erniebot-agent/src/erniebot_agent/file_io/caching.py @@ -0,0 +1,327 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging +import pathlib +import weakref +from typing import Any, Awaitable, Callable, NoReturn, Optional, Tuple, final + +import anyio +from typing_extensions import Self, TypeAlias + +from erniebot_agent.file_io.remote_file import RemoteFile +from erniebot_agent.utils.mixins import Closeable + +_DEFAULT_CACHE_TIMEOUT = 3600 + +logger = logging.getLogger(__name__) + +DiscardCallback: TypeAlias = Callable[[], None] +ContentsReader: TypeAlias = Callable[[], Awaitable[bytes]] +CacheFactory: TypeAlias = Callable[[pathlib.Path, Optional[DiscardCallback]], "FileCache"] + + +def bind_cache_to_remote_file(cache: "FileCache", file: RemoteFile) -> "RemoteFileWithCache": + return RemoteFileWithCache.from_remote_file_and_cache(file, cache) + + +def create_file_cache(cache_path: pathlib.Path, discard_callback: Optional[DiscardCallback]) -> "FileCache": + return FileCache( + cache_path=cache_path, + active=False, + discard_callback=discard_callback, + expire_after=_DEFAULT_CACHE_TIMEOUT, + ) + + +def create_default_file_cache_manager() -> "FileCacheManager": + return FileCacheManager(cache_factory=create_file_cache) + + +@final +class FileCache(object): + def __init__( + self, + *, + cache_path: pathlib.Path, + active: bool, + discard_callback: Optional[DiscardCallback], + expire_after: Optional[float], + ) -> None: + super().__init__() + + self._cache_path = cache_path + self._active = active + self._discard_callback = discard_callback + self._expire_after = expire_after + + self._lock = asyncio.Lock() + self._discarded = False + self._finalizer: Optional[weakref.finalize] + if self._discard_callback is not None: + self._finalizer = weakref.finalize(self, self._discard_callback) + else: + self._finalizer = None + self._expire_handle: Optional[asyncio.TimerHandle] = None + if self._active: + self.activate() + + @property + def cache_path(self) -> pathlib.Path: + return self._cache_path + + @property + def active(self) -> bool: + return self._active + + 
@property + def discarded(self) -> bool: + return self._discarded + + @property + def alive(self) -> bool: + return not self._discarded + + def __del__(self) -> None: + self._on_discard() + + def __copy__(self) -> NoReturn: + raise RuntimeError(f"{self.__class__.__name__} is not copyable.") + + def __deepcopy__(self, memo: Any) -> NoReturn: + raise RuntimeError(f"{self.__class__.__name__} is not deepcopyable.") + + async def fetch_or_update_contents(self, contents_reader: ContentsReader) -> bytes: + if self._discarded: + raise CacheDiscardedError + async with self._lock: + if self._discarded: + raise CacheDiscardedError + if not self._active: + contents = await self._update_contents(await contents_reader()) + self.activate() + else: + contents = await self._fetch_contents() + return contents + + async def update_contents(self, contents_reader: ContentsReader) -> bytes: + if self._discarded: + raise CacheDiscardedError + new_contents = await contents_reader() + async with self._lock: + if self._discarded: + raise CacheDiscardedError + self.deactivate() + contents = await self._update_contents(new_contents) + self.activate() + return contents + + def activate(self) -> None: + def _expire_callback(cache_ref: weakref.ReferenceType) -> None: + cache = cache_ref() + if cache is not None: + cache._deactivate() + + if self._discarded: + raise CacheDiscardedError + self._cancel_expire_callback() + # Should we inject the event loop from outside? 
+ loop = asyncio.get_running_loop() + if self._expire_after is not None: + self._expire_handle = loop.call_later(self._expire_after, _expire_callback, weakref.ref(self)) + self._active = True + + def deactivate(self) -> None: + self._cancel_expire_callback() + self._deactivate() + + async def discard(self) -> None: + async with self._lock: + if not self._discarded: + self._on_discard() + self._discarded = True + + async def _fetch_contents(self) -> bytes: + return await anyio.Path(self.cache_path).read_bytes() + + async def _update_contents(self, new_contents: bytes) -> bytes: + async with await anyio.open_file(self.cache_path, "wb") as f: + await f.write(new_contents) + return new_contents + + def _deactivate(self) -> None: + self._active = False + + def _on_discard(self) -> None: + self.deactivate() + if self._discard_callback is not None: + if self._finalizer is not None: + self._finalizer.detach() + self._discard_callback() + + def _cancel_expire_callback(self) -> None: + if self._expire_handle is not None: + self._expire_handle.cancel() + self._expire_handle = None + + +@final +class FileCacheManager(Closeable): + def __init__(self, cache_factory: CacheFactory): + super().__init__() + self._cache_factory = cache_factory + self._file_id_to_cache: weakref.WeakValueDictionary[str, FileCache] = weakref.WeakValueDictionary() + self._closed = False + + @property + def closed(self) -> bool: + return self._closed + + async def get_or_create_cache( + self, + file_id: str, + cache_path: pathlib.Path, + *, + discard_callback: Optional[DiscardCallback] = None, + init_cache_in_sync: Optional[bool] = None, + ) -> Tuple[FileCache, bool]: + self.ensure_not_closed() + cache = None + if self._has_cache(file_id): + cache = self._get_cache(file_id) + if cache is not None and cache.alive: + return cache, False + else: + cache = self._create_cache( + file_id, + cache_path, + discard_callback=discard_callback, + init_cache_in_sync=init_cache_in_sync, + ) + self._set_cache(file_id, 
cache) + return cache, True + + async def get_cache( + self, + file_id: str, + ) -> FileCache: + self.ensure_not_closed() + try: + return self._get_cache(file_id) + except KeyError as e: + raise CacheNotFoundError from e + + async def remove_cache(self, file_id: str) -> None: + self.ensure_not_closed() + try: + cache = self._get_cache(file_id) + except KeyError as e: + raise CacheNotFoundError from e + await cache.discard() + self._delete_cache(file_id) + + async def close(self) -> None: + if not self._closed: + for cache in self._file_id_to_cache.values(): + await cache.discard() + self._clear_caches() + self._closed = True + + def _create_cache( + self, + file_id: str, + cache_path: pathlib.Path, + *, + discard_callback: Optional[DiscardCallback], + init_cache_in_sync: Optional[bool], + ) -> FileCache: + cache = self._cache_factory(cache_path, discard_callback) + if init_cache_in_sync is not None: + if init_cache_in_sync: + cache.activate() + else: + cache.deactivate() + return cache + + def _has_cache(self, file_id: str) -> bool: + return file_id in self._file_id_to_cache + + def _get_cache(self, file_id: str) -> FileCache: + return self._file_id_to_cache[file_id] + + def _set_cache(self, file_id: str, cache: FileCache) -> None: + self._file_id_to_cache[file_id] = cache + + def _delete_cache(self, file_id: str) -> None: + del self._file_id_to_cache[file_id] + + def _clear_caches(self) -> None: + self._file_id_to_cache.clear() + + +class RemoteFileWithCache(RemoteFile): + def __init__( + self, + cache: FileCache, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._cache = cache + + @classmethod + def from_remote_file_and_cache(cls, file: RemoteFile, cache: FileCache) -> Self: + return cls( + cache, + id=file.id, + filename=file.filename, + byte_size=file.byte_size, + created_at=file.created_at, + purpose=file.purpose, + metadata=file.metadata, + client=file.client, + ) + + @property + def cached(self) -> bool: + return 
self._cache.alive and self._cache.active + + @property + def cache_path(self) -> Optional[pathlib.Path]: + return self._cache.cache_path if self._cache.alive else None + + async def read_contents(self) -> bytes: + try: + return await self._cache.fetch_or_update_contents(super().read_contents) + except CacheDiscardedError: + return await super().read_contents() + + async def delete(self) -> None: + await super().delete() + await self._cache.discard() + + async def update_cache(self) -> None: + try: + await self._cache.update_contents(super().read_contents) + except CacheDiscardedError: + logger.warning("Cache is no longer available.") + + +class CacheDiscardedError(Exception): + pass + + +class CacheNotFoundError(Exception): + pass diff --git a/erniebot-agent/src/erniebot_agent/file_io/factory.py b/erniebot-agent/src/erniebot_agent/file_io/factory.py index 77f46cab1..22a16a544 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/factory.py +++ b/erniebot-agent/src/erniebot_agent/file_io/factory.py @@ -12,18 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import atexit import functools -from typing import Optional +from typing import List, Optional from erniebot_agent.file_io.file_manager import FileManager from erniebot_agent.file_io.remote_file import AIStudioFileClient +from erniebot_agent.utils.mixins import Closeable +from erniebot_agent.utils.temp_file import create_tracked_temp_dir + +_objects_to_close: List[Closeable] = [] @functools.lru_cache(maxsize=None) -def get_file_manager(access_token: Optional[str] = None) -> FileManager: +def get_global_file_manager(*, access_token: Optional[str]) -> FileManager: if access_token is None: - # TODO: Use a default global access token. 
- return FileManager() + file_manager = FileManager(save_dir=create_tracked_temp_dir()) else: remote_file_client = AIStudioFileClient(access_token=access_token) - return FileManager(remote_file_client) + _objects_to_close.append(remote_file_client) + file_manager = FileManager(remote_file_client, save_dir=create_tracked_temp_dir()) + _objects_to_close.append(file_manager) + + return file_manager + + +def _close_objects(): + async def _close_objects_sequentially(): + for obj in _objects_to_close: + await obj.close() + + if _objects_to_close: + # Since async atexit is not officially supported by Python, + # we start a new event loop to do the cleanup. + asyncio.run(_close_objects_sequentially()) + _objects_to_close.clear() + + +# FIXME: The exit handler may not be called when using multiprocessing. +atexit.register(_close_objects) diff --git a/erniebot-agent/src/erniebot_agent/file_io/file_manager.py b/erniebot-agent/src/erniebot_agent/file_io/file_manager.py index 631d524ac..3b91e90bd 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/file_manager.py +++ b/erniebot-agent/src/erniebot_agent/file_io/file_manager.py @@ -12,56 +12,71 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import functools import logging import os import pathlib import tempfile import uuid -import weakref -from typing import Any, Dict, List, Literal, Optional, Union, overload +from types import TracebackType +from typing import Any, Dict, List, Literal, Optional, Type, Union, final, overload import anyio -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias +from erniebot_agent.file_io import protocol from erniebot_agent.file_io.base import File -from erniebot_agent.file_io.file_registry import FileRegistry, get_file_registry +from erniebot_agent.file_io.caching import ( + FileCacheManager, + RemoteFileWithCache, + bind_cache_to_remote_file, + create_default_file_cache_manager, +) +from erniebot_agent.file_io.file_registry import FileRegistry from erniebot_agent.file_io.local_file import LocalFile, create_local_file_from_path -from erniebot_agent.file_io.protocol import FilePurpose from erniebot_agent.file_io.remote_file import RemoteFile, RemoteFileClient -from erniebot_agent.utils.exception import FileError +from erniebot_agent.utils.exceptions import FileError +from erniebot_agent.utils.mixins import Closeable logger = logging.getLogger(__name__) FilePath: TypeAlias = Union[str, os.PathLike] -class FileManager(object): - _remote_file_client: Optional[RemoteFileClient] +@final +class FileManager(Closeable): + _file_cache_manager: Optional[FileCacheManager] + _temp_dir: Optional[tempfile.TemporaryDirectory] = None def __init__( self, remote_file_client: Optional[RemoteFileClient] = None, *, - auto_register: bool = True, save_dir: Optional[FilePath] = None, + prune_on_close: bool = True, + cache_remote_files: bool = True, ) -> None: super().__init__() - if remote_file_client is not None: - self._remote_file_client = remote_file_client - else: - self._remote_file_client = None - self._auto_register = auto_register + + self._remote_file_client = remote_file_client if save_dir is not None: self._save_dir = pathlib.Path(save_dir) else: 
# This can be done lazily, but we need to be careful about race conditions. - self._save_dir = self._fs_create_temp_dir() + self._temp_dir = self._create_temp_dir() + self._save_dir = pathlib.Path(self._temp_dir.name) + self._prune_on_close = prune_on_close + self._cache_remote_files = cache_remote_files - self._file_registry = get_file_registry() + self._file_registry = FileRegistry() + if self._cache_remote_files: + self._file_cache_manager = create_default_file_cache_manager() + else: + self._file_cache_manager = None + self._fully_managed_files: List[Union[LocalFile, RemoteFile]] = [] - @property - def registry(self) -> FileRegistry: - return self._file_registry + self._closed = False + self._clean_up_cache_files_on_discard = True @property def remote_file_client(self) -> RemoteFileClient: @@ -70,14 +85,29 @@ def remote_file_client(self) -> RemoteFileClient: else: return self._remote_file_client + @property + def closed(self): + return self._closed + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.close() + @overload async def create_file_from_path( self, file_path: FilePath, *, - file_purpose: FilePurpose = ..., + file_purpose: protocol.FilePurpose = ..., file_metadata: Optional[Dict[str, Any]] = ..., - file_type: Literal["local"] = ..., + file_type: Literal["local"], ) -> LocalFile: ... @@ -86,26 +116,35 @@ async def create_file_from_path( self, file_path: FilePath, *, - file_purpose: FilePurpose = ..., + file_purpose: protocol.FilePurpose = ..., file_metadata: Optional[Dict[str, Any]] = ..., file_type: Literal["remote"], ) -> RemoteFile: ... 
+ @overload + async def create_file_from_path( + self, + file_path: FilePath, + *, + file_purpose: protocol.FilePurpose = ..., + file_metadata: Optional[Dict[str, Any]] = ..., + file_type: None = ..., + ) -> Union[LocalFile, RemoteFile]: + ... + async def create_file_from_path( self, file_path: FilePath, *, - file_purpose: FilePurpose = "assistants", + file_purpose: protocol.FilePurpose = "assistants", file_metadata: Optional[Dict[str, Any]] = None, file_type: Optional[Literal["local", "remote"]] = None, ) -> Union[LocalFile, RemoteFile]: + self.ensure_not_closed() file: Union[LocalFile, RemoteFile] if file_type is None: - if self._remote_file_client is not None: - file_type = "remote" - else: - file_type = "local" + file_type = self._get_default_file_type() if file_type == "local": file = await self.create_local_file_from_path(file_path, file_purpose, file_metadata) elif file_type == "remote": @@ -117,10 +156,10 @@ async def create_file_from_path( async def create_local_file_from_path( self, file_path: FilePath, - file_purpose: FilePurpose, + file_purpose: protocol.FilePurpose, file_metadata: Optional[Dict[str, Any]], ) -> LocalFile: - file = create_local_file_from_path( + file = await self._create_local_file_from_path( pathlib.Path(file_path), file_purpose, file_metadata or {}, @@ -129,13 +168,18 @@ async def create_local_file_from_path( return file async def create_remote_file_from_path( - self, file_path: FilePath, file_purpose: FilePurpose, file_metadata: Optional[Dict[str, Any]] + self, + file_path: FilePath, + file_purpose: protocol.FilePurpose, + file_metadata: Optional[Dict[str, Any]], ) -> RemoteFile: - file = await self.remote_file_client.upload_file( - pathlib.Path(file_path), file_purpose, file_metadata or {} + file = await self._create_remote_file_from_path( + pathlib.Path(file_path), + file_purpose, + file_metadata, ) - if self._auto_register: - self._file_registry.register_file(file) + self._file_registry.register_file(file) + 
self._fully_managed_files.append(file) return file @overload @@ -144,9 +188,9 @@ async def create_file_from_bytes( file_contents: bytes, filename: str, *, - file_purpose: FilePurpose = ..., + file_purpose: protocol.FilePurpose = ..., file_metadata: Optional[Dict[str, Any]] = ..., - file_type: Literal["local"] = ..., + file_type: Literal["local"], ) -> LocalFile: ... @@ -156,80 +200,202 @@ async def create_file_from_bytes( file_contents: bytes, filename: str, *, - file_purpose: FilePurpose = ..., + file_purpose: protocol.FilePurpose = ..., file_metadata: Optional[Dict[str, Any]] = ..., file_type: Literal["remote"], ) -> RemoteFile: ... + @overload async def create_file_from_bytes( self, file_contents: bytes, filename: str, *, - file_purpose: FilePurpose = "assistants", + file_purpose: protocol.FilePurpose = ..., + file_metadata: Optional[Dict[str, Any]] = ..., + file_type: None = ..., + ) -> Union[LocalFile, RemoteFile]: + ... + + async def create_file_from_bytes( + self, + file_contents: bytes, + filename: str, + *, + file_purpose: protocol.FilePurpose = "assistants", file_metadata: Optional[Dict[str, Any]] = None, file_type: Optional[Literal["local", "remote"]] = None, ) -> Union[LocalFile, RemoteFile]: - # Can we do this with in-memory files? 
- file_path = await self._fs_create_file( - prefix=pathlib.PurePath(filename).stem, suffix=pathlib.PurePath(filename).suffix + self.ensure_not_closed() + if file_type is None: + file_type = self._get_default_file_type() + file_path = self._get_unique_file_path( + prefix=pathlib.PurePath(filename).stem, + suffix=pathlib.PurePath(filename).suffix, ) + async_file_path = anyio.Path(file_path) + await async_file_path.touch() + should_remove_file = True try: - async with await file_path.open("wb") as f: + async with await async_file_path.open("wb") as f: await f.write(file_contents) - if file_type is None: - if self._remote_file_client is not None: - file_type = "remote" - else: - file_type = "local" - file = await self.create_file_from_path( - file_path, - file_purpose=file_purpose, - file_metadata=file_metadata, - file_type=file_type, - ) + file: Union[LocalFile, RemoteFile] + if file_type == "local": + file = await self._create_local_file_from_path(file_path, file_purpose, file_metadata) + should_remove_file = False + elif file_type == "remote": + file = await self._create_remote_file_from_path( + file_path, + file_purpose, + file_metadata, + cache_path=file_path if self._cache_remote_files else None, + init_cache_in_sync=True, + ) + if self._cache_remote_files: + if isinstance(file, RemoteFileWithCache): + cache_path = file.cache_path + if cache_path is not None and await async_file_path.samefile(cache_path): + should_remove_file = False + else: + raise ValueError(f"Unsupported file type: {file_type}") finally: - if file_type == "remote": - await file_path.unlink() + if should_remove_file: + await async_file_path.unlink() + self._file_registry.register_file(file) + self._fully_managed_files.append(file) return file async def retrieve_remote_file_by_id(self, file_id: str) -> RemoteFile: + self.ensure_not_closed() file = await self.remote_file_client.retrieve_file(file_id) - if self._auto_register: - self._file_registry.register_file(file, allow_overwrite=True) + if 
self._cache_remote_files: + file = await self._cache_remote_file( + file, + cache_path=None, + init_cache_in_sync=None, + ) + self._file_registry.register_file(file) return file + async def list_remote_files(self) -> List[RemoteFile]: + self.ensure_not_closed() + files = await self.remote_file_client.list_files() + return files + def look_up_file_by_id(self, file_id: str) -> Optional[File]: + self.ensure_not_closed() file = self._file_registry.look_up_file(file_id) if file is None: raise FileError( - f"File with ID '{file_id}' not found. " - "Please check if the file exists and the `file_id` is correct." + f"File with ID {repr(file_id)} not found. " + "Please check if `file_id` is correct and the file is registered." ) return file - async def list_remote_files(self) -> List[RemoteFile]: - files = await self.remote_file_client.list_files() - if self._auto_register: - for file in files: - self._file_registry.register_file(file, allow_overwrite=True) - return files + def list_registered_files(self) -> List[File]: + self.ensure_not_closed() + return self._file_registry.list_files() + + async def prune(self) -> None: + for file in self._fully_managed_files: + if isinstance(file, RemoteFile): + await file.delete() + if self._file_cache_manager is not None and isinstance(file, RemoteFileWithCache): + await self._file_cache_manager.remove_cache(file.id) + elif isinstance(file, LocalFile): + assert self._save_dir in file.path.parents + await anyio.Path(file.path).unlink() + else: + raise AssertionError("Unexpected file type") + self._file_registry.unregister_file(file) + self._fully_managed_files.clear() - async def _fs_create_file( + async def close(self) -> None: + if not self._closed: + if self._prune_on_close: + await self.prune() + if self._file_cache_manager is not None: + await self._file_cache_manager.close() + if self._temp_dir is not None: + self._clean_up_temp_dir(self._temp_dir) + self._closed = True + + async def _create_local_file_from_path( + self, + 
file_path: pathlib.Path, + file_purpose: protocol.FilePurpose, + file_metadata: Optional[Dict[str, Any]], + ) -> LocalFile: + return create_local_file_from_path( + pathlib.Path(file_path), + file_purpose, + file_metadata or {}, + ) + + async def _create_remote_file_from_path( + self, + file_path: pathlib.Path, + file_purpose: protocol.FilePurpose, + file_metadata: Optional[Dict[str, Any]], + *, + cache_path: Optional[pathlib.Path] = None, + init_cache_in_sync: Optional[bool] = None, + ) -> RemoteFile: + file = await self.remote_file_client.upload_file(file_path, file_purpose, file_metadata or {}) + if self._cache_remote_files: + file = await self._cache_remote_file( + file, cache_path=cache_path, init_cache_in_sync=init_cache_in_sync + ) + return file + + async def _cache_remote_file( + self, + file: RemoteFile, + *, + cache_path: Optional[pathlib.Path], + init_cache_in_sync: Optional[bool], + ) -> RemoteFileWithCache: + def _remove_cache_file(cache_path: pathlib.Path, logger: logging.Logger) -> None: + try: + cache_path.unlink(missing_ok=True) + except Exception as e: + logger.warning("Failed to remove cache file: %s", cache_path, exc_info=e) + + if self._file_cache_manager is None: + raise RuntimeError("Chaching is not enabled.") + if cache_path is None: + cache_path = self._get_unique_file_path() + init_cache_in_sync = None + if not cache_path.exists(): + await anyio.Path(cache_path).touch() + cache, _ = await self._file_cache_manager.get_or_create_cache( + file.id, + cache_path, + discard_callback=functools.partial(_remove_cache_file, pathlib.Path(cache_path), logger) + if self._clean_up_cache_files_on_discard + else None, + init_cache_in_sync=init_cache_in_sync, + ) + return bind_cache_to_remote_file(cache, file) + + def _get_default_file_type(self) -> Literal["local", "remote"]: + if self._remote_file_client is not None: + return "remote" + else: + return "local" + + def _get_unique_file_path( self, prefix: Optional[str] = None, suffix: Optional[str] = None - 
) -> anyio.Path: + ) -> pathlib.Path: filename = f"{prefix or ''}{str(uuid.uuid4())}{suffix or ''}" - file_path = anyio.Path(self._save_dir / filename) - await file_path.touch() + file_path = self._save_dir / filename return file_path - def _fs_create_temp_dir(self) -> pathlib.Path: + @staticmethod + def _create_temp_dir() -> tempfile.TemporaryDirectory: temp_dir = tempfile.TemporaryDirectory() - # The temporary directory shall be cleaned up when the file manager is - # garbage collected. - weakref.finalize(self, self._clean_up_temp_dir, temp_dir) - return pathlib.Path(temp_dir.name) + return temp_dir @staticmethod def _clean_up_temp_dir(temp_dir: tempfile.TemporaryDirectory) -> None: diff --git a/erniebot-agent/src/erniebot_agent/file_io/file_registry.py b/erniebot-agent/src/erniebot_agent/file_io/file_registry.py index 9e21badc3..607562db8 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/file_registry.py +++ b/erniebot-agent/src/erniebot_agent/file_io/file_registry.py @@ -12,44 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import threading -from typing import Dict, List, Optional +from typing import Dict, List, Optional, final from erniebot_agent.file_io.base import File -from erniebot_agent.utils.misc import Singleton -class FileRegistry(metaclass=Singleton): +class BaseFileRegistry(object): + def register_file(self, file: File, *, allow_overwrite: bool = False, check_type: bool = True) -> None: + raise NotImplementedError + + def unregister_file(self, file: File) -> None: + raise NotImplementedError + + def look_up_file(self, file_id: str) -> Optional[File]: + raise NotImplementedError + + def list_files(self) -> List[File]: + raise NotImplementedError + + +@final +class FileRegistry(BaseFileRegistry): def __init__(self) -> None: super().__init__() self._id_to_file: Dict[str, File] = {} - self._lock = threading.Lock() - def register_file(self, file: File, *, allow_overwrite: bool = False) -> None: + def register_file(self, file: File, *, allow_overwrite: bool = False, check_type: bool = True) -> None: file_id = file.id - with self._lock: - if not allow_overwrite and file_id in self._id_to_file: - raise RuntimeError(f"ID {repr(file_id)} is already registered.") - self._id_to_file[file_id] = file + if file_id in self._id_to_file: + if not allow_overwrite: + raise RuntimeError(f"File with ID {repr(file_id)} is already registered.") + else: + if check_type and type(file) is not type(self._id_to_file[file_id]): # noqa: E721 + raise RuntimeError("Cannot register a file with a different type.") + self._id_to_file[file_id] = file def unregister_file(self, file: File) -> None: file_id = file.id - with self._lock: - if file_id not in self._id_to_file: - raise RuntimeError(f"ID {repr(file_id)} is not registered.") - self._id_to_file.pop(file_id) + if file_id not in self._id_to_file: + raise RuntimeError(f"File with ID {repr(file_id)} is not registered.") + self._id_to_file.pop(file_id) def look_up_file(self, file_id: str) -> Optional[File]: - with self._lock: - return 
self._id_to_file.get(file_id, None) + return self._id_to_file.get(file_id, None) def list_files(self) -> List[File]: - with self._lock: - return list(self._id_to_file.values()) - - -_file_registry = FileRegistry() - - -def get_file_registry() -> FileRegistry: - return _file_registry + return list(self._id_to_file.values()) diff --git a/erniebot-agent/src/erniebot_agent/file_io/local_file.py b/erniebot-agent/src/erniebot_agent/file_io/local_file.py index 9b2b4ef5e..b992105b0 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/local_file.py +++ b/erniebot-agent/src/erniebot_agent/file_io/local_file.py @@ -13,18 +13,36 @@ # limitations under the License. import pathlib -import time import uuid from typing import Any, Dict import anyio +from erniebot_agent.file_io import protocol from erniebot_agent.file_io.base import File -from erniebot_agent.file_io.protocol import ( - FilePurpose, - build_local_file_id_from_uuid, - is_local_file_id, -) + + +def create_local_file_from_path( + file_path: pathlib.Path, + file_purpose: protocol.FilePurpose, + file_metadata: Dict[str, Any], +) -> "LocalFile": + if not file_path.exists(): + raise FileNotFoundError(f"File {file_path} does not exist.") + file_id = _generate_local_file_id() + filename = file_path.name + byte_size = file_path.stat().st_size + created_at = protocol.get_timestamp() + file = LocalFile( + id=file_id, + filename=filename, + byte_size=byte_size, + created_at=created_at, + purpose=file_purpose, + metadata=file_metadata, + path=file_path, + ) + return file class LocalFile(File): @@ -34,13 +52,15 @@ def __init__( id: str, filename: str, byte_size: int, - created_at: int, - purpose: FilePurpose, + created_at: str, + purpose: protocol.FilePurpose, metadata: Dict[str, Any], path: pathlib.Path, + validate_file_id: bool = True, ) -> None: - if not is_local_file_id(id): - raise ValueError(f"Invalid file ID: {id}") + if validate_file_id: + if not protocol.is_local_file_id(id): + raise ValueError(f"Invalid file ID: {id}") 
super().__init__( id=id, filename=filename, @@ -60,28 +80,5 @@ def _get_attrs_str(self) -> str: return attrs_str -def create_local_file_from_path( - file_path: pathlib.Path, - file_purpose: FilePurpose, - file_metadata: Dict[str, Any], -) -> LocalFile: - if not file_path.exists(): - raise FileNotFoundError(f"File {file_path} does not exist.") - file_id = _generate_local_file_id() - filename = file_path.name - byte_size = file_path.stat().st_size - created_at = int(time.time()) - file = LocalFile( - id=file_id, - filename=filename, - byte_size=byte_size, - created_at=created_at, - purpose=file_purpose, - metadata=file_metadata, - path=file_path, - ) - return file - - def _generate_local_file_id(): - return build_local_file_id_from_uuid(str(uuid.uuid1())) + return protocol.create_local_file_id_from_uuid(str(uuid.uuid1())) diff --git a/erniebot-agent/src/erniebot_agent/file_io/protocol.py b/erniebot-agent/src/erniebot_agent/file_io/protocol.py index 4b6e28d5d..f636dfe0f 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/protocol.py +++ b/erniebot-agent/src/erniebot_agent/file_io/protocol.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import datetime import re from typing import List, Literal @@ -28,10 +29,14 @@ _compiled_remote_file_id_pattern = re.compile(_REMOTE_FILE_ID_PATTERN) -def build_local_file_id_from_uuid(uuid: str) -> str: +def create_local_file_id_from_uuid(uuid: str) -> str: return _LOCAL_FILE_ID_PREFIX + uuid +def get_timestamp() -> str: + return datetime.datetime.now().isoformat(sep=" ", timespec="seconds") + + def is_file_id(str_: str) -> bool: return is_local_file_id(str_) or is_remote_file_id(str_) diff --git a/erniebot-agent/src/erniebot_agent/file_io/remote_file.py b/erniebot-agent/src/erniebot_agent/file_io/remote_file.py index f48335350..02d499173 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/remote_file.py +++ b/erniebot-agent/src/erniebot_agent/file_io/remote_file.py @@ -21,9 +21,10 @@ import aiohttp +from erniebot_agent.file_io import protocol from erniebot_agent.file_io.base import File -from erniebot_agent.file_io.protocol import FilePurpose, is_remote_file_id -from erniebot_agent.utils.exception import FileError +from erniebot_agent.utils.exceptions import FileError +from erniebot_agent.utils.mixins import Closeable class RemoteFile(File): @@ -33,13 +34,15 @@ def __init__( id: str, filename: str, byte_size: int, - created_at: int, - purpose: FilePurpose, + created_at: str, + purpose: protocol.FilePurpose, metadata: Dict[str, Any], client: "RemoteFileClient", + validate_file_id: bool = True, ) -> None: - if not is_remote_file_id(id): - raise FileError(f"Invalid file ID: {id}") + if validate_file_id: + if not protocol.is_remote_file_id(id): + raise FileError(f"Invalid file ID: {id}") super().__init__( id=id, filename=filename, @@ -50,6 +53,10 @@ def __init__( ) self._client = client + @property + def client(self) -> "RemoteFileClient": + return self._client + async def read_contents(self) -> bytes: file_contents = await self._client.retrieve_file_contents(self.id) return file_contents @@ -64,10 +71,10 @@ def get_file_repr_with_url(self, url: str) -> str: 
return f"{self.get_file_repr()}{url}" -class RemoteFileClient(metaclass=abc.ABCMeta): +class RemoteFileClient(Closeable, metaclass=abc.ABCMeta): @abc.abstractmethod async def upload_file( - self, file_path: pathlib.Path, file_purpose: FilePurpose, file_metadata: Dict[str, Any] + self, file_path: pathlib.Path, file_purpose: protocol.FilePurpose, file_metadata: Dict[str, Any] ) -> RemoteFile: raise NotImplementedError @@ -104,11 +111,19 @@ def __init__( ) -> None: super().__init__() self._access_token = access_token + if aiohttp_session is None: + aiohttp_session = self._create_aiohttp_session() self._session = aiohttp_session + self._closed = False + + @property + def closed(self) -> bool: + return self._closed async def upload_file( - self, file_path: pathlib.Path, file_purpose: FilePurpose, file_metadata: Dict[str, Any] + self, file_path: pathlib.Path, file_purpose: protocol.FilePurpose, file_metadata: Dict[str, Any] ) -> RemoteFile: + self.ensure_not_closed() url = self._get_url(self._UPLOAD_ENDPOINT) headers: Dict[str, str] = {} headers.update(self._get_default_headers()) @@ -125,9 +140,10 @@ async def upload_file( raise_for_status=True, ) result = self._get_result_from_response_body(resp_bytes) - return self._build_file_obj_from_dict(result) + return self._create_file_obj_from_dict(result) async def retrieve_file(self, file_id: str) -> RemoteFile: + self.ensure_not_closed() url = self._get_url(self._RETRIEVE_ENDPOINT).format(file_id=file_id) headers: Dict[str, str] = {} headers.update(self._get_default_headers()) @@ -138,9 +154,10 @@ async def retrieve_file(self, file_id: str) -> RemoteFile: raise_for_status=True, ) result = self._get_result_from_response_body(resp_bytes) - return self._build_file_obj_from_dict(result) + return self._create_file_obj_from_dict(result) async def retrieve_file_contents(self, file_id: str) -> bytes: + self.ensure_not_closed() url = self._get_url(self._RETRIEVE_CONTENTS_ENDPOINT).format(file_id=file_id) headers: Dict[str, str] = {} 
headers.update(self._get_default_headers()) @@ -153,6 +170,7 @@ async def retrieve_file_contents(self, file_id: str) -> bytes: return resp_bytes async def list_files(self) -> List[RemoteFile]: + self.ensure_not_closed() url = self._get_url(self._LIST_ENDPOINT) headers: Dict[str, str] = {} headers.update(self._get_default_headers()) @@ -165,7 +183,7 @@ async def list_files(self) -> List[RemoteFile]: result = self._get_result_from_response_body(resp_bytes) files: List[RemoteFile] = [] for item in result: - file = self._build_file_obj_from_dict(item) + file = self._create_file_obj_from_dict(item) files.append(file) return files @@ -186,14 +204,12 @@ async def create_temporary_url(self, file_id: str, expire_after: float) -> str: result = self._get_result_from_response_body(resp_bytes) return result["fileUrl"] - async def _request(self, *args: Any, **kwargs: Any) -> bytes: - if self._session is not None: - async with self._session.request(*args, **kwargs) as response: - return await response.read() - else: - async with aiohttp.ClientSession(**self._get_session_config()) as session: - async with session.request(*args, **kwargs) as response: - return await response.read() + async def close(self) -> None: + if not self._closed: + await self._session.close() + + def _create_aiohttp_session(self) -> aiohttp.ClientSession: + return aiohttp.ClientSession(**self._get_session_config()) def _get_session_config(self) -> Dict[str, Any]: return {} @@ -203,9 +219,13 @@ def _get_default_headers(self) -> Dict[str, str]: "Authorization": f"token {self._access_token}", } - def _build_file_obj_from_dict(self, dict_: Dict[str, Any]) -> RemoteFile: + async def _request(self, *args: Any, **kwargs: Any) -> bytes: + async with self._session.request(*args, **kwargs) as response: + return await response.read() + + def _create_file_obj_from_dict(self, dict_: Dict[str, Any]) -> RemoteFile: metadata: Dict[str, Any] - if "meta" in dict_: + if dict_.get("meta"): metadata = json.loads(dict_["meta"]) 
if not isinstance(metadata, dict): raise FileError(f"Invalid metadata: {dict_['meta']}") diff --git a/erniebot-agent/src/erniebot_agent/retrieval/baizhong_search.py b/erniebot-agent/src/erniebot_agent/retrieval/baizhong_search.py index b68a3ef7f..b984491bf 100644 --- a/erniebot-agent/src/erniebot_agent/retrieval/baizhong_search.py +++ b/erniebot-agent/src/erniebot_agent/retrieval/baizhong_search.py @@ -5,7 +5,7 @@ import requests -from erniebot_agent.utils.exception import BaizhongError +from erniebot_agent.utils.exceptions import BaizhongError logger = logging.getLogger(__name__) diff --git a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py index 80a32d1e7..5096188e2 100644 --- a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py @@ -18,7 +18,7 @@ tool_response_contains_file, ) from erniebot_agent.utils.common import is_json_response -from erniebot_agent.utils.exception import RemoteToolError +from erniebot_agent.utils.exceptions import RemoteToolError from erniebot_agent.utils.logging import logger @@ -183,7 +183,6 @@ async def send_request(self, tool_arguments: Dict[str, Any]) -> dict: raise RemoteToolError( f"Unsupported content type: {self.tool_view.parameters_content_type}", stage="Executing" ) - if self.tool_view.method == "get": response = requests.get(url, **requests_inputs) # type: ignore elif self.tool_view.method == "post": diff --git a/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py b/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py index d90d04f86..19eea4544 100644 --- a/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py +++ b/erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py @@ -2,6 +2,7 @@ import copy import json +import logging import os import tempfile from dataclasses import asdict, dataclass, field @@ -12,7 +13,7 @@ from openapi_spec_validator.readers import read_from_filename from yaml 
import safe_dump -from erniebot_agent.file_io import get_file_manager +from erniebot_agent.file_io import get_global_file_manager from erniebot_agent.file_io.file_manager import FileManager from erniebot_agent.messages import AIMessage, FunctionCall, HumanMessage, Message from erniebot_agent.tools.remote_tool import RemoteTool, tool_registor @@ -24,9 +25,10 @@ scrub_dict, ) from erniebot_agent.tools.utils import validate_openapi_yaml -from erniebot_agent.utils.exception import RemoteToolError +from erniebot_agent.utils.exceptions import RemoteToolError from erniebot_agent.utils.http import url_file_exists -from erniebot_agent.utils.logging import logger + +logger = logging.getLogger(__name__) @dataclass @@ -125,7 +127,7 @@ def get_tool(self, tool_name: str) -> RemoteTool: TOOL_CLASS = tool_registor.get_tool_class(self.info.title) return TOOL_CLASS( - paths[0], + copy.deepcopy(paths[0]), self.servers[0].url, self.headers, self.info.version, @@ -192,7 +194,7 @@ def from_openapi_dict( ) if file_manager is None: - file_manager = get_file_manager(access_token) + file_manager = get_global_file_manager(access_token=access_token) return RemoteToolkit( openapi=openapi_dict["openapi"], diff --git a/erniebot-agent/src/erniebot_agent/tools/schema.py b/erniebot-agent/src/erniebot_agent/tools/schema.py index 32cdec1ea..aa22bb052 100644 --- a/erniebot-agent/src/erniebot_agent/tools/schema.py +++ b/erniebot-agent/src/erniebot_agent/tools/schema.py @@ -24,7 +24,7 @@ from pydantic.fields import FieldInfo from erniebot_agent.utils.common import create_enum_class -from erniebot_agent.utils.exception import RemoteToolError +from erniebot_agent.utils.exceptions import RemoteToolError INVALID_FIELD_NAME = "__invalid_field_name__" diff --git a/erniebot-agent/src/erniebot_agent/tools/utils.py b/erniebot-agent/src/erniebot_agent/tools/utils.py index d7e062106..4e91a1a5b 100644 --- a/erniebot-agent/src/erniebot_agent/tools/utils.py +++ b/erniebot-agent/src/erniebot_agent/tools/utils.py @@ 
-18,7 +18,7 @@ get_typing_list_type, ) from erniebot_agent.utils.common import get_file_suffix, is_json_response -from erniebot_agent.utils.exception import RemoteToolError +from erniebot_agent.utils.exceptions import RemoteToolError from erniebot_agent.utils.logging import logger diff --git a/erniebot-agent/src/erniebot_agent/utils/exception.py b/erniebot-agent/src/erniebot_agent/utils/exceptions.py similarity index 95% rename from erniebot-agent/src/erniebot_agent/utils/exception.py rename to erniebot-agent/src/erniebot_agent/utils/exceptions.py index 3ca029cd2..4e7475927 100644 --- a/erniebot-agent/src/erniebot_agent/utils/exception.py +++ b/erniebot-agent/src/erniebot_agent/utils/exceptions.py @@ -36,3 +36,7 @@ def __init__(self, message: str): def __str__(self): return self.message + + +class ObjectClosedError(Exception): + pass diff --git a/erniebot-agent/src/erniebot_agent/utils/gradio_mixin.py b/erniebot-agent/src/erniebot_agent/utils/mixins.py similarity index 85% rename from erniebot-agent/src/erniebot_agent/utils/gradio_mixin.py rename to erniebot-agent/src/erniebot_agent/utils/mixins.py index 38d56ace8..8e3d5fedf 100644 --- a/erniebot-agent/src/erniebot_agent/utils/gradio_mixin.py +++ b/erniebot-agent/src/erniebot_agent/utils/mixins.py @@ -1,14 +1,33 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + import base64 import os import tempfile -from typing import Any, List +from typing import TYPE_CHECKING, Any, List, Protocol -from erniebot_agent.file_io.base import File -from erniebot_agent.file_io.file_manager import FileManager -from erniebot_agent.tools.tool_manager import ToolManager from erniebot_agent.utils.common import get_file_type +from erniebot_agent.utils.exceptions import ObjectClosedError from erniebot_agent.utils.html_format import IMAGE_HTML, ITEM_LIST_HTML +if TYPE_CHECKING: + from erniebot_agent.file_io.base import File + from erniebot_agent.file_io.file_manager import FileManager + from erniebot_agent.tools.tool_manager import ToolManager + class GradioMixin: _file_manager: FileManager # make mypy happy @@ -93,7 +112,7 @@ def _clear(): self.reset_memory() return None, None, None, None - async def _upload(file: List[gr.utils.NamedString], history: list): + async def _upload(file, history): nonlocal _uploaded_file_cache for single_file in file: upload_file = await self._file_manager.create_file_from_path(single_file.name) @@ -101,7 +120,7 @@ async def _upload(file: List[gr.utils.NamedString], history: list): history = history + [((single_file.name,), None)] size = len(file) - output_lis = self._file_manager.registry.list_files() + output_lis = self._file_manager.list_registered_files() item = "" for i in range(len(output_lis) - size): item += f'
  • {str(output_lis[i]).strip("<>")}
  • ' @@ -146,7 +165,7 @@ def _messages_to_dicts(messages): ) with gr.Accordion("Files", open=False): - file_lis = self._file_manager.registry.list_files() + file_lis = self._file_manager.list_registered_files() all_files = gr.HTML(value=file_lis, label="All input files") with gr.Accordion("Tools", open=False): attached_tools = self._tool_manager.get_tools() @@ -209,3 +228,16 @@ def _messages_to_dicts(messages): else: allowed_paths = [td] demo.launch(allowed_paths=allowed_paths, **launch_kwargs) + + +class Closeable(Protocol): + @property + def closed(self) -> bool: + ... + + async def close(self) -> None: + ... + + def ensure_not_closed(self) -> None: + if self.closed: + raise ObjectClosedError(f"{repr(self)} is closed.") diff --git a/erniebot-agent/src/erniebot_agent/utils/temp_file.py b/erniebot-agent/src/erniebot_agent/utils/temp_file.py new file mode 100644 index 000000000..ecb98ac43 --- /dev/null +++ b/erniebot-agent/src/erniebot_agent/utils/temp_file.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import atexit +import logging +import pathlib +import tempfile +from tempfile import TemporaryDirectory +from typing import Any, List + +logger = logging.getLogger(__name__) + +_tracked_temp_dirs: List[TemporaryDirectory] = [] + + +def create_tracked_temp_dir(*args: Any, **kwargs: Any) -> pathlib.Path: + # Borrowed from + # https://github.com/pypa/pipenv/blob/247a14369d300a6980a8dd634d9060bf6f582d2d/pipenv/utils/fileutils.py#L197 + def _cleanup() -> None: + try: + temp_dir.cleanup() + except Exception as e: + logger.warning("Failed to clean up temporary directory: %s", temp_dir.name, exc_info=e) + + temp_dir = tempfile.TemporaryDirectory(*args, **kwargs) + _tracked_temp_dirs.append(temp_dir) + atexit.register(_cleanup) + return pathlib.Path(temp_dir.name) diff --git a/erniebot-agent/tests/integration_tests/apihub/base.py b/erniebot-agent/tests/integration_tests/apihub/base.py index a7384efb8..1a8c92812 100644 --- a/erniebot-agent/tests/integration_tests/apihub/base.py +++ b/erniebot-agent/tests/integration_tests/apihub/base.py @@ -10,19 +10,20 @@ from erniebot_agent.agents.functional_agent import FunctionalAgent from erniebot_agent.chat_models import ERNIEBot -from erniebot_agent.file_io import get_file_manager +from erniebot_agent.file_io.file_manager import FileManager from erniebot_agent.memory import WholeMemory from erniebot_agent.tools import RemoteToolkit from erniebot_agent.tools.tool_manager import ToolManager class RemoteToolTesting(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: + async def asyncSetUp(self) -> None: self.temp_dir = tempfile.mkdtemp() - self.file_manager = get_file_manager() + self.file_manager = FileManager() - def tearDown(self) -> None: + async def asyncTearDown(self) -> None: shutil.rmtree(self.temp_dir) + await self.file_manager.close() def download_file(self, url, file_name: Optional[str] = None): image_response = requests.get(url) diff --git a/erniebot-agent/tests/integration_tests/apihub/test_doc_analysis.py 
b/erniebot-agent/tests/integration_tests/apihub/test_doc_analysis.py index 7b38f282c..b99342a3b 100644 --- a/erniebot-agent/tests/integration_tests/apihub/test_doc_analysis.py +++ b/erniebot-agent/tests/integration_tests/apihub/test_doc_analysis.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio - import pytest from erniebot_agent.tools.remote_toolkit import RemoteToolkit @@ -10,11 +8,9 @@ class TestRemoteTool(RemoteToolTesting): - def setUp(self) -> None: - super().setUp() - self.file = asyncio.run( - self.file_manager.create_file_from_path(self.download_fixture_file("城市管理执法办法.pdf")) - ) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.file = await self.file_manager.create_file_from_path(self.download_fixture_file("城市管理执法办法.pdf")) @pytest.mark.asyncio async def test_doc_analysis(self): diff --git a/erniebot-agent/tests/integration_tests/apihub/test_handwriting.py b/erniebot-agent/tests/integration_tests/apihub/test_handwriting.py index 58eba1471..3d3d1c021 100644 --- a/erniebot-agent/tests/integration_tests/apihub/test_handwriting.py +++ b/erniebot-agent/tests/integration_tests/apihub/test_handwriting.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio - import pytest from erniebot_agent.tools.remote_toolkit import RemoteToolkit @@ -10,10 +8,10 @@ class TestRemoteTool(RemoteToolTesting): - def setUp(self) -> None: - super().setUp() - self.file = asyncio.run( - self.file_manager.create_file_from_path(self.download_fixture_file("shouxiezi.png")) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.file = await self.file_manager.create_file_from_path( + self.download_fixture_file("shouxiezi.png") ) @pytest.mark.asyncio diff --git a/erniebot-agent/tests/integration_tests/apihub/test_img_transform.py b/erniebot-agent/tests/integration_tests/apihub/test_img_transform.py index 4115d36ca..7482cacb4 100644 --- a/erniebot-agent/tests/integration_tests/apihub/test_img_transform.py +++ 
b/erniebot-agent/tests/integration_tests/apihub/test_img_transform.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio - import pytest from erniebot_agent.tools.remote_toolkit import RemoteToolkit @@ -10,11 +8,9 @@ class TestRemoteTool(RemoteToolTesting): - def setUp(self) -> None: - super().setUp() - self.file = asyncio.run( - self.file_manager.create_file_from_path(self.download_fixture_file("trans.png")) - ) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.file = await self.file_manager.create_file_from_path(self.download_fixture_file("trans.png")) @pytest.mark.asyncio async def test_img_style_trans(self): diff --git a/erniebot-agent/tests/integration_tests/apihub/test_ocr.py b/erniebot-agent/tests/integration_tests/apihub/test_ocr.py index 4e1c199cb..3199a1164 100644 --- a/erniebot-agent/tests/integration_tests/apihub/test_ocr.py +++ b/erniebot-agent/tests/integration_tests/apihub/test_ocr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import json import pytest @@ -11,10 +10,10 @@ class TestRemoteTool(RemoteToolTesting): - def setUp(self) -> None: - super().setUp() - self.file = asyncio.run( - self.file_manager.create_file_from_path(self.download_fixture_file("ocr_table.png")) + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.file = await self.file_manager.create_file_from_path( + self.download_fixture_file("ocr_table.png") ) @pytest.mark.asyncio diff --git a/erniebot-agent/tests/unit_tests/agents/callback/test_callback_manager.py b/erniebot-agent/tests/unit_tests/agents/callback/test_callback_manager.py index 7f3b30fb0..03f6a715f 100644 --- a/erniebot-agent/tests/unit_tests/agents/callback/test_callback_manager.py +++ b/erniebot-agent/tests/unit_tests/agents/callback/test_callback_manager.py @@ -13,7 +13,7 @@ @pytest.mark.asyncio async def test_callback_manager_hit(): - def _assert_num_calls(handler): + def _assert_all_counts(handler): assert handler.run_starts == 1 
assert handler.llm_starts == 1 assert handler.llm_ends == 1 @@ -27,27 +27,40 @@ def _assert_num_calls(handler): llm = mock.Mock(spec=ChatModel) tool = mock.Mock(spec=Tool) - handler1 = CountingCallbackHandler() - handler2 = CountingCallbackHandler() - callback_manager = CallbackManager(handlers=[handler1, handler2]) + handler = CountingCallbackHandler() + callback_manager = CallbackManager(handlers=[handler]) await callback_manager.on_run_start(agent, "") + assert handler.run_starts == 1 + await callback_manager.on_llm_start(agent, llm, []) + assert handler.llm_starts == 1 + await callback_manager.on_llm_end( agent, llm, AIMessage(content="", function_call=None, token_usage={"prompt_tokens": 0, "completion_tokens": 0}), ) + assert handler.llm_ends == 1 + await callback_manager.on_llm_error(agent, llm, Exception()) + assert handler.llm_errors == 1 + await callback_manager.on_tool_start(agent, tool, "{}") + assert handler.tool_starts == 1 + await callback_manager.on_tool_end(agent, tool, "{}") + assert handler.tool_ends == 1 + await callback_manager.on_tool_error(agent, tool, Exception()) + assert handler.tool_errors == 1 + await callback_manager.on_run_end( agent, AgentResponse(text="", chat_history=[], actions=[], files=[], status="FINISHED") ) + assert handler.run_ends == 1 - _assert_num_calls(handler1) - _assert_num_calls(handler2) + _assert_all_counts(handler) @pytest.mark.asyncio @@ -55,14 +68,20 @@ async def test_callback_manager_add_remove_handlers(): handler1 = CountingCallbackHandler() handler2 = CountingCallbackHandler() callback_manager = CallbackManager(handlers=[handler1]) + assert len(callback_manager.handlers) == 1 + with pytest.raises(RuntimeError): callback_manager.add_handler(handler1) + callback_manager.remove_handler(handler1) assert len(callback_manager.handlers) == 0 + callback_manager.add_handler(handler1) assert len(callback_manager.handlers) == 1 + callback_manager.add_handler(handler2) assert len(callback_manager.handlers) == 2 + 
callback_manager.remove_all_handlers() assert len(callback_manager.handlers) == 0 diff --git a/erniebot-agent/tests/unit_tests/agents/test_agent_response_annotations.py b/erniebot-agent/tests/unit_tests/agents/test_agent_response_annotations.py index 6253327ae..eced30299 100644 --- a/erniebot-agent/tests/unit_tests/agents/test_agent_response_annotations.py +++ b/erniebot-agent/tests/unit_tests/agents/test_agent_response_annotations.py @@ -1,17 +1,19 @@ -import asyncio import unittest from typing import List, Literal from erniebot_agent.agents.schema import AgentFile, AgentResponse -from erniebot_agent.file_io import get_file_manager +from erniebot_agent.file_io.file_manager import FileManager -class TestAgentResponseAnnotations(unittest.TestCase): - def setUp(self): +class TestAgentResponseAnnotations(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): self.test = "" - self.file_manager = get_file_manager() - self.file1 = asyncio.run(self.file_manager.create_file_from_bytes(b"test1", "test1.txt")) - self.file2 = asyncio.run(self.file_manager.create_file_from_bytes(b"test2", "test2.txt")) + self.file_manager = FileManager() + self.file1 = await self.file_manager.create_file_from_bytes(b"test1", "test1.txt") + self.file2 = await self.file_manager.create_file_from_bytes(b"test2", "test2.txt") + + async def asyncTearDown(self): + await self.file_manager.close() def test_agent_response_onefile_oneagentfile(self): agent_file = AgentFile(file=self.file1, type="input", used_by="") diff --git a/erniebot-agent/tests/unit_tests/agents/test_functional_agent.py b/erniebot-agent/tests/unit_tests/agents/test_functional_agent.py index d3a886aa4..b6ef7e4f6 100644 --- a/erniebot-agent/tests/unit_tests/agents/test_functional_agent.py +++ b/erniebot-agent/tests/unit_tests/agents/test_functional_agent.py @@ -97,17 +97,22 @@ async def test_functional_agent_load_unload_tools(identity_tool, no_input_no_out @pytest.mark.asyncio -async def 
test_functional_agent_run_llm(identity_tool): +async def test_functional_agent_run_llm_return_text(): output_message = AIMessage("Hello!", function_call=None) agent = FunctionalAgent( llm=FakeChatModelWithPresetResponses(responses=[output_message]), tools=[], memory=FakeMemory(), ) + llm_response = await agent._async_run_llm(messages=[HumanMessage("Hello, world!")]) + assert isinstance(llm_response.message, AIMessage) assert llm_response.message == output_message + +@pytest.mark.asyncio +async def test_functional_agent_run_llm_return_function_call(identity_tool): output_message = AIMessage( "", function_call=FunctionCall( @@ -119,7 +124,9 @@ async def test_functional_agent_run_llm(identity_tool): tools=[identity_tool], memory=FakeMemory(), ) + llm_response = await agent._async_run_llm(messages=[HumanMessage("Hello, world!")]) + assert isinstance(llm_response.message, AIMessage) assert llm_response.message == output_message @@ -161,24 +168,6 @@ async def test_functional_agent_memory(identity_tool): AIMessage("", function_call=function_call), AIMessage("", function_call=function_call), AIMessage(output_text, function_call=None), - ] - ) - agent = FunctionalAgent( - llm=llm, - tools=[identity_tool], - memory=FakeMemory(), - ) - await agent.async_run(input_text) - messages_in_memory = agent.memory.get_messages() - assert len(messages_in_memory) == 2 - assert isinstance(messages_in_memory[0], HumanMessage) - assert messages_in_memory[0].content == input_text - assert isinstance(messages_in_memory[1], AIMessage) - assert messages_in_memory[1].content == output_text - - llm = FakeChatModelWithPresetResponses( - responses=[ - AIMessage(output_text, function_call=None), AIMessage(output_text, function_call=None), AIMessage("This message should not be remembered.", function_call=None), AIMessage("This message should not be remembered, either.", function_call=None), @@ -189,6 +178,7 @@ async def test_functional_agent_memory(identity_tool): tools=[identity_tool], 
memory=FakeMemory(), ) + await agent.async_run(input_text) messages_in_memory = agent.memory.get_messages() assert len(messages_in_memory) == 2 @@ -196,6 +186,7 @@ async def test_functional_agent_memory(identity_tool): assert messages_in_memory[0].content == input_text assert isinstance(messages_in_memory[1], AIMessage) assert messages_in_memory[1].content == output_text + await agent.async_run(input_text) assert len(agent.memory.get_messages()) == 2 + 2 agent.reset_memory() @@ -223,5 +214,7 @@ async def test_functional_agent_max_steps(identity_tool): memory=FakeMemory(), max_steps=2, ) + response = await agent.async_run("Run!") + assert response.status == "STOPPED" diff --git a/erniebot-agent/tests/unit_tests/file_io/test_caching.py b/erniebot-agent/tests/unit_tests/file_io/test_caching.py new file mode 100644 index 000000000..5f73f541b --- /dev/null +++ b/erniebot-agent/tests/unit_tests/file_io/test_caching.py @@ -0,0 +1,399 @@ +import asyncio +import contextlib +import copy +import os +import pathlib +import tempfile +from unittest import mock + +import pytest +from tests.unit_tests.testing_utils.mocks.mock_remote_file_client_server import ( + FakeRemoteFileClient, + FakeRemoteFileServer, +) + +from erniebot_agent.file_io.caching import ( + CacheDiscardedError, + CacheNotFoundError, + FileCache, + FileCacheManager, +) +from erniebot_agent.utils.exceptions import ObjectClosedError + + +@contextlib.asynccontextmanager +async def create_file_cache_manager(cache_factory=None): + if cache_factory is None: + cache_factory = create_file_cache + manager = FileCacheManager(cache_factory=cache_factory) + + yield manager + + await manager.close() + + +def create_file_cache(cache_path, active, discard_callback=None, expire_after=None): + return FileCache( + cache_path=cache_path, + active=active, + discard_callback=discard_callback, + expire_after=expire_after, + ) + + +def repeat_bytes_in_coro_func(bytes_): + async def _get_bytes(): + return bytes_ + + return _get_bytes + 
+ +@contextlib.contextmanager +def create_temporary_file(): + fd, path = tempfile.mkstemp() + os.close(fd) + try: + yield pathlib.Path(path) + finally: + if os.path.exists(path): + os.unlink(path) + + +class AwaitableContents(object): + def __init__(self, contents): + super().__init__() + self.contents = contents + + def __await__(self): + yield + return self.contents + + +@pytest.mark.asyncio +async def test_file_cache_manager_crd(): + server = FakeRemoteFileServer() + + with server.start(): + client = FakeRemoteFileClient(server) + + with client.protocol.follow(): + with create_temporary_file() as file_path, create_temporary_file() as cache_path: + with open(file_path, "wb") as f: + f.write(b"Simple is better than complex.") + file = await client.upload_file(file_path, "assistants", {}) + + async with create_file_cache_manager() as manager: + cache, created = await manager.get_or_create_cache(file.id, cache_path=cache_path) + assert cache.cache_path.samefile(cache_path) + assert created + + retrieved_cache, created = await manager.get_or_create_cache( + file.id, cache_path=cache_path + ) + assert retrieved_cache is cache + assert not created + + retrieved_cache = await manager.get_cache(file.id) + assert retrieved_cache is cache + + await manager.remove_cache(file.id) + with pytest.raises(CacheNotFoundError): + await manager.get_cache(file.id) + + +@pytest.mark.asyncio +async def test_file_cache_manager_close(): + server = FakeRemoteFileServer() + + with server.start(): + client = FakeRemoteFileClient(server) + + with client.protocol.follow(): + with create_temporary_file() as file_path, create_temporary_file() as cache_path: + with open(file_path, "wb") as f: + f.write(b"Simple is better than complex.") + file = await client.upload_file(file_path, "assistants", {}) + + async with create_file_cache_manager() as manager: + cache, _ = await manager.get_or_create_cache(file.id, cache_path=cache_path) + + await manager.close() + + assert manager.closed + assert 
cache.discarded + + +@pytest.mark.asyncio +async def test_file_cache_manager_after_closing(): + server = FakeRemoteFileServer() + + with server.start(): + client = FakeRemoteFileClient(server) + + with client.protocol.follow(): + with create_temporary_file() as file_path, create_temporary_file() as cache_path: + with open(file_path, "wb") as f: + f.write(b"Flat is better than nested.") + file = await client.upload_file(file_path, "assistants", {}) + + async with create_file_cache_manager() as manager: + await manager.close() + + with pytest.raises(ObjectClosedError): + await manager.get_or_create_cache(file.id, cache_path=cache_path) + + with pytest.raises(ObjectClosedError): + await manager.get_cache(file.id) + + with pytest.raises(ObjectClosedError): + await manager.remove_cache(file.id) + + +@pytest.mark.asyncio +async def test_file_cache_manager_auto_remove_unreachable_cache(): + server = FakeRemoteFileServer() + + with server.start(): + client = FakeRemoteFileClient(server) + + with client.protocol.follow(): + with create_temporary_file() as file_path, create_temporary_file() as cache_path: + with open(file_path, "wb") as f: + f.write(b"Flat is better than nested.") + file = await client.upload_file(file_path, "assistants", {}) + + async with create_file_cache_manager() as manager: + cache, _ = await manager.get_or_create_cache(file.id, cache_path=cache_path) + + del cache + + with pytest.raises(CacheNotFoundError): + await manager.get_cache(file.id) + + +@pytest.mark.parametrize("active", [False, True]) +@pytest.mark.asyncio +async def test_file_cache_init_active(active): + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=active, + ) + + if active: + assert cache.active + else: + assert not cache.active + + +@pytest.mark.asyncio +async def test_file_cache_timeout(): + expire_after = 0.1 + + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=True, 
+ expire_after=expire_after, + ) + await asyncio.sleep(expire_after * 1.5) + + assert not cache.active + + +@pytest.mark.asyncio +async def test_file_cache_fetch_or_update_contents(): + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + ) + + contents1 = b"Special cases aren't special enough to break the rules." + result = await cache.fetch_or_update_contents(repeat_bytes_in_coro_func(contents1)) + assert result == contents1 + assert cache.cache_path.read_bytes() == contents1 + + contents2 = b"Although practicality beats purity." + result = await cache.fetch_or_update_contents(repeat_bytes_in_coro_func(contents2)) + assert result == contents1 + assert cache.cache_path.read_bytes() == contents1 + + +@pytest.mark.asyncio +async def test_file_cache_update_contents(): + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + ) + + contents1 = b"Errors should never pass silently." + result = await cache.update_contents(repeat_bytes_in_coro_func(contents1)) + assert result == contents1 + assert cache.cache_path.read_bytes() == contents1 + + contents2 = b"Unless explicitly silenced." 
+ result = await cache.update_contents(repeat_bytes_in_coro_func(contents2)) + assert result == contents2 + assert cache.cache_path.read_bytes() == contents2 + + +@pytest.mark.asyncio +async def test_file_cache_activate_deactivate(): + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + ) + + cache.activate() + assert cache.active + + cache.deactivate() + assert not cache.active + + +@pytest.mark.asyncio +async def test_file_cache_discard(): + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=True, + ) + + await cache.discard() + + assert cache.discarded + assert not cache.alive + assert not cache.active + + +@pytest.mark.asyncio +async def test_file_cache_fetch_or_update_contents_called_concurrently(): + contents = b"Explicit is better than implicit." + num_coros = 4 + + mock_ = mock.Mock(return_value=AwaitableContents(contents)) + + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + ) + + results = await asyncio.gather(*[cache.fetch_or_update_contents(mock_) for _ in range(num_coros)]) + + for result in results: + assert result == contents + assert mock_.call_count == 1 + + +@pytest.mark.asyncio +async def test_file_cache_update_contents_called_concurrently(): + contents = b"Readability counts." 
+ num_coros = 4 + + mock_ = mock.Mock(return_value=AwaitableContents(contents)) + + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + ) + + results = await asyncio.gather(*[cache.update_contents(mock_) for _ in range(num_coros)]) + + for result in results: + assert result == contents + assert mock_.call_count == num_coros + + +@pytest.mark.asyncio +async def test_file_cache_discard_called_concurrently(): + num_coros = 4 + + mock_ = mock.Mock() + + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + discard_callback=mock_, + ) + + await asyncio.gather(*[cache.discard() for _ in range(num_coros)]) + + assert cache.discarded + assert mock_.call_count == 1 + + +@pytest.mark.asyncio +async def test_file_cache_discard_callback(): + mock_ = mock.Mock() + + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + discard_callback=mock_, + ) + + await cache.discard() + mock_.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_file_cache_discard_on_destruction(): + mock_ = mock.Mock() + + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + discard_callback=mock_, + ) + + del cache + mock_.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_file_cache_after_discarding(): + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=True, + ) + + await cache.discard() + + with pytest.raises(CacheDiscardedError): + await cache.fetch_or_update_contents( + repeat_bytes_in_coro_func(b"If the implementation is hard to explain, it's a bad idea.") + ) + + with pytest.raises(CacheDiscardedError): + await cache.update_contents( + repeat_bytes_in_coro_func( + b"If the implementation is easy to explain, it may be a good idea." 
+ ) + ) + + with pytest.raises(CacheDiscardedError): + cache.activate() + + +@pytest.mark.asyncio +async def test_file_cache_copy(): + with create_temporary_file() as cache_path: + cache = create_file_cache( + cache_path=cache_path, + active=False, + ) + + with pytest.raises(RuntimeError): + copy.copy(cache) + + with pytest.raises(RuntimeError): + copy.deepcopy(cache) diff --git a/erniebot-agent/tests/unit_tests/file_io/test_local_file.py b/erniebot-agent/tests/unit_tests/file_io/test_local_file.py new file mode 100644 index 000000000..b06052ba3 --- /dev/null +++ b/erniebot-agent/tests/unit_tests/file_io/test_local_file.py @@ -0,0 +1,17 @@ +import pathlib +import tempfile + +import erniebot_agent.file_io.protocol as protocol +from erniebot_agent.file_io.local_file import LocalFile, create_local_file_from_path + + +def test_create_local_file_from_path(): + with tempfile.TemporaryDirectory() as td: + file_path = pathlib.Path(td) / "temp_file" + file_path.touch() + file_purpose = "assistants" + + file = create_local_file_from_path(file_path, file_purpose, {}) + + assert isinstance(file, LocalFile) + assert protocol.is_local_file_id(file.id) diff --git a/erniebot-agent/tests/unit_tests/testing_utils/mocks/mock_remote_file_client_server.py b/erniebot-agent/tests/unit_tests/testing_utils/mocks/mock_remote_file_client_server.py new file mode 100644 index 000000000..188173f7a --- /dev/null +++ b/erniebot-agent/tests/unit_tests/testing_utils/mocks/mock_remote_file_client_server.py @@ -0,0 +1,187 @@ +import contextlib +import re +import uuid + +import erniebot_agent.file_io.protocol as protocol +from erniebot_agent.file_io.remote_file import RemoteFile, RemoteFileClient + + +class FakeRemoteFileProtocol(object): + _UUID_PATTERN = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" + _FILE_ID_PREFIX = r"file-fake-remote-" + _FILE_ID_PATTERN = _FILE_ID_PREFIX + _UUID_PATTERN + + _followed = False + + @classmethod + def is_remote_file_id(cls, str_): + return 
re.fullmatch(cls._FILE_ID_PATTERN, str_) is not None + + @classmethod + def extract_remote_file_ids(cls, str_): + return re.findall(cls._FILE_ID_PATTERN, str_) + + @classmethod + def generate_remote_file_id(cls): + return cls._FILE_ID_PREFIX + str(uuid.uuid4()) + + get_timestamp = protocol.get_timestamp + + @classmethod + @contextlib.contextmanager + def follow(cls, old_protocol=None): + if not cls._followed: + names_methods_to_monkey_patch = ( + "is_remote_file_id", + "extract_remote_file_ids", + "get_timestamp", + ) + + if old_protocol is None: + old_protocol = protocol + + _old_methods = {} + + for method_name in names_methods_to_monkey_patch: + old_method = getattr(old_protocol, method_name) + new_method = getattr(cls, method_name) + _old_methods[method_name] = old_method + setattr(old_protocol, method_name, new_method) + + cls._followed = True + + yield + + for method_name in names_methods_to_monkey_patch: + old_method = _old_methods[method_name] + setattr(old_protocol, method_name, old_method) + + cls._followed = False + + else: + yield + + +class FakeRemoteFileClient(RemoteFileClient): + _protocol = FakeRemoteFileProtocol + + def __init__(self, server): + super().__init__() + if server.protocol is not self.protocol: + raise ValueError("Server and client do not share the same protocol.") + self._server = server + + @property + def protocol(self): + return self._protocol + + @property + def server(self): + if not self._server.started: + raise RuntimeError("Server is not running.") + return self._server + + async def upload_file(self, file_path, file_purpose, file_metadata): + result = await self.server.upload_file(file_path, file_purpose, file_metadata) + return self._create_file_obj_from_dict(result) + + async def retrieve_file(self, file_id): + result = await self.server.retrieve_file(file_id) + return self._create_file_obj_from_dict(result) + + async def retrieve_file_contents(self, file_id): + return await self.server.retrieve_file_contents(file_id) + + 
async def list_files(self): + result = await self.server.list_files() + files = [] + for item in result: + file = self._create_file_obj_from_dict(item) + files.append(file) + return files + + async def delete_file(self, file_id) -> None: + await self.server.delete_file(file_id) + + async def create_temporary_url(self, file_id, expire_after): + raise RuntimeError("Method not supported") + + def _create_file_obj_from_dict(self, dict_): + with self._protocol.follow(): + return RemoteFile( + id=dict_["id"], + filename=dict_["filename"], + byte_size=dict_["byte_size"], + created_at=dict_["created_at"], + purpose=dict_["purpose"], + metadata=dict_["metadata"], + client=self, + ) + + +class FakeRemoteFileServer(object): + _protocol = FakeRemoteFileProtocol + + def __init__(self): + super().__init__() + self._storage = None + + @property + def protocol(self): + return self._protocol + + @property + def storage(self): + return self._storage + + @property + def started(self): + return self._storage is not None + + async def upload_file(self, file_path, file_purpose, file_metadata): + id_ = self._protocol.generate_remote_file_id() + filename = file_path.name + byte_size = file_path.stat().st_size + created_at = self._protocol.get_timestamp() + with file_path.open("rb") as f: + contents = f.read() + file = dict( + id=id_, + filename=filename, + byte_size=byte_size, + created_at=created_at, + purpose=file_purpose, + metadata=file_metadata, + contents=contents, + ) + self._storage[id_] = file + return file + + async def retrieve_file(self, file_id): + try: + return self._storage[file_id] + except KeyError as e: + raise RuntimeError("File not found") from e + + async def retrieve_file_contents(self, file_id): + try: + file = self._storage[file_id] + except KeyError as e: + raise RuntimeError("File not found") from e + else: + return file["contents"] + + async def list_files(self): + return list(self._storage.values()) + + async def delete_file(self, file_id) -> None: + try: + 
return self._storage[file_id] + except KeyError as e: + raise RuntimeError("File not found") from e + + @contextlib.contextmanager + def start(self): + self._storage = {} + yield self + self._storage = None diff --git a/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py b/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py index 0e81e9163..80fa7314b 100644 --- a/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py +++ b/erniebot-agent/tests/unit_tests/tools/test_file_in_tool.py @@ -14,7 +14,6 @@ from __future__ import annotations -import asyncio import base64 import json import os @@ -110,7 +109,13 @@ def is_port_in_use(port): return True -class TestToolWithFile(unittest.TestCase): +class TestToolWithFile(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.file_manager = FileManager() + + async def asyncTearDown(self): + await self.file_manager.close() + def avaliable_free_port(self, exclude=None): exclude = exclude or [] for port in range(8000, 9000): @@ -145,7 +150,7 @@ def wait_until_server_is_ready(self): print("waiting for server ...") time.sleep(1) - def test_plugin_schema(self): + async def test_plugin_schema(self): self.wait_until_server_is_ready() with tempfile.TemporaryDirectory() as tempdir: openapi_file = os.path.join(tempdir, "openapi.yaml") @@ -153,25 +158,27 @@ def test_plugin_schema(self): with open(openapi_file, "w", encoding="utf-8") as f: f.write(content) - toolkit = RemoteToolkit.from_openapi_file(openapi_file) + toolkit = RemoteToolkit.from_openapi_file(openapi_file, file_manager=self.file_manager) tool = toolkit.get_tool("getFile") # tool.tool_name should have `tool_name_prefix`` prepended self.assertEqual(tool.tool_name, "TestRemoteTool/v1/getFile") - file_manager = FileManager() - input_file = asyncio.run(file_manager.create_file_from_path(self.file_path)) - result = asyncio.run(tool(file=input_file.id)) + input_file = await self.file_manager.create_file_from_path(self.file_path) + result = await 
tool(file=input_file.id) self.assertIn("response_file", result) file_id = result["response_file"] - file = file_manager.look_up_file_by_id(file_id=file_id) - content = asyncio.run(file.read_contents()) + file = self.file_manager.look_up_file_by_id(file_id=file_id) + content = await file.read_contents() self.assertEqual(content.decode("utf-8"), self.content) class TestPlainJsonFileParser(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: + async def asyncSetUp(self): self.file_manager = FileManager() + async def asyncTearDown(self): + await self.file_manager.close() + def create_fake_response(self, body: dict): the_response = Response() the_response.code = "expired" @@ -224,7 +231,7 @@ async def test_plain_file(self): file_path = os.path.join(temp_dir, "openapi.yaml") with open(file_path, "w+", encoding="utf-8") as f: f.write(yaml_content) - toolkit = RemoteToolkit.from_openapi_file(file_path) + toolkit = RemoteToolkit.from_openapi_file(file_path, file_manager=self.file_manager) response = self.create_fake_response(body) tool = toolkit.get_tools()[-1] @@ -241,9 +248,12 @@ async def test_plain_file(self): class TestJsonNestFileParser(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: + async def asyncSetUp(self): self.file_manager = FileManager() + async def asyncTearDown(self): + await self.file_manager.close() + def create_fake_response(self, body: dict): the_response = Response() the_response.code = "expired" @@ -300,7 +310,7 @@ async def test_plain_file(self): file_path = os.path.join(temp_dir, "openapi.yaml") with open(file_path, "w+", encoding="utf-8") as f: f.write(yaml_content) - toolkit = RemoteToolkit.from_openapi_file(file_path) + toolkit = RemoteToolkit.from_openapi_file(file_path, file_manager=self.file_manager) response = self.create_fake_response(body) tool = toolkit.get_tools()[-1] @@ -319,9 +329,12 @@ async def test_plain_file(self): class TestJsonNestListFileParser(unittest.IsolatedAsyncioTestCase): - def setUp(self) -> None: + 
async def asyncSetUp(self): self.file_manager = FileManager() + async def asyncTearDown(self): + await self.file_manager.close() + def create_fake_response(self, body: dict): the_response = Response() the_response.code = "expired" @@ -381,7 +394,7 @@ async def test_plain_file(self): file_path = os.path.join(temp_dir, "openapi.yaml") with open(file_path, "w+", encoding="utf-8") as f: f.write(yaml_content) - toolkit = RemoteToolkit.from_openapi_file(file_path) + toolkit = RemoteToolkit.from_openapi_file(file_path, file_manager=self.file_manager) response = self.create_fake_response(body) tool = toolkit.get_tools()[-1] diff --git a/erniebot/Makefile b/erniebot/Makefile index f547eff63..cd2c1f7ea 100644 --- a/erniebot/Makefile +++ b/erniebot/Makefile @@ -1,6 +1,9 @@ -.DEFAULT_GOAL = format lint type_check +.DEFAULT_GOAL = dev files_to_format_and_lint = src examples tests +.PHONY: dev +dev: format lint type_check + .PHONY: format format: python -m black $(files_to_format_and_lint) diff --git a/erniebot/src/erniebot/backends/bce.py b/erniebot/src/erniebot/backends/bce.py index c97d31a36..3e6f48ad9 100644 --- a/erniebot/src/erniebot/backends/bce.py +++ b/erniebot/src/erniebot/backends/bce.py @@ -75,8 +75,8 @@ def request( ) except (errors.TokenExpiredError, errors.InvalidTokenError): logging.warning( - "The access token provided is invalid or has expired. " - "An automatic update will be performed before retrying." + "The access token provided is invalid or has expired." + " An automatic update will be performed before retrying." ) access_token = self._auth_manager.update_auth_token() url_with_token = add_query_params(url, [("access_token", access_token)]) @@ -126,8 +126,8 @@ async def arequest( ) except (errors.TokenExpiredError, errors.InvalidTokenError): logging.warning( - "The access token provided is invalid or has expired. " - "An automatic update will be performed before retrying." + "The access token provided is invalid or has expired." 
+ " An automatic update will be performed before retrying." ) # XXX: The default executor is used. access_token = await loop.run_in_executor(None, self._auth_manager.update_auth_token) diff --git a/erniebot/src/erniebot/types.py b/erniebot/src/erniebot/types.py index c52d1d043..78a5065d7 100644 --- a/erniebot/src/erniebot/types.py +++ b/erniebot/src/erniebot/types.py @@ -15,19 +15,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import ( - IO, - TYPE_CHECKING, - Any, - AsyncIterator, - Dict, - Iterator, - Optional, - TypeVar, -) +from typing import IO, Any, AsyncIterator, Dict, Iterator, Optional, TypeVar -if TYPE_CHECKING: - from typing_extensions import TypeAlias +from typing_extensions import TypeAlias from .response import EBResponse diff --git a/erniebot/tests/test_chat_completion_aio.py b/erniebot/tests/test_chat_completion_aio.py index 4b2e05da9..0ef8ef87b 100644 --- a/erniebot/tests/test_chat_completion_aio.py +++ b/erniebot/tests/test_chat_completion_aio.py @@ -78,6 +78,8 @@ async def test_chat_completion_aio(target, args): erniebot.api_type = "qianfan" - asyncio.run(test_chat_completion_aio(acreate_chat_completion, args=("ernie-turbo",))) + async def main(): + await test_chat_completion_aio(acreate_chat_completion, args=("ernie-turbo",)) + await test_chat_completion_aio(acreate_chat_completion_stream, args=("ernie-turbo",)) - asyncio.run(test_chat_completion_aio(acreate_chat_completion_stream, args=("ernie-turbo",))) + asyncio.run(main())