From b2803333e13348ec3141d671f8ff6b8e7936d5ec Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Thu, 21 Dec 2023 20:55:59 +0800 Subject: [PATCH] Fix protocol --- .../src/erniebot_agent/file_io/protocol.py | 11 ++- .../chat_models/test_chat_model.py | 1 - .../mocks/mock_remote_file_client_server.py | 98 +++---------------- 3 files changed, 22 insertions(+), 88 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/file_io/protocol.py b/erniebot-agent/src/erniebot_agent/file_io/protocol.py index f636dfe0f..6caad1424 100644 --- a/erniebot-agent/src/erniebot_agent/file_io/protocol.py +++ b/erniebot-agent/src/erniebot_agent/file_io/protocol.py @@ -14,7 +14,7 @@ import datetime import re -from typing import List, Literal +from typing import Generator, List, Literal from typing_extensions import TypeAlias @@ -23,7 +23,8 @@ _LOCAL_FILE_ID_PREFIX = "file-local-" _UUID_PATTERN = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" _LOCAL_FILE_ID_PATTERN = _LOCAL_FILE_ID_PREFIX + _UUID_PATTERN -_REMOTE_FILE_ID_PATTERN = r"file-[0-9]{15}" +_REMOTE_FILE_ID_PREFIX = "file-" +_REMOTE_FILE_ID_PATTERN = _REMOTE_FILE_ID_PREFIX + r"[0-9]{15}" _compiled_local_file_id_pattern = re.compile(_LOCAL_FILE_ID_PATTERN) _compiled_remote_file_id_pattern = re.compile(_REMOTE_FILE_ID_PATTERN) @@ -59,3 +60,9 @@ def extract_local_file_ids(str_: str) -> List[str]: def extract_remote_file_ids(str_: str) -> List[str]: return _compiled_remote_file_id_pattern.findall(str_) + + +def generate_fake_remote_file_ids() -> Generator[str, None, None]: + counter = 0 + while True: + yield _REMOTE_FILE_ID_PREFIX + f"{counter:015d}" diff --git a/erniebot-agent/tests/integration_tests/chat_models/test_chat_model.py b/erniebot-agent/tests/integration_tests/chat_models/test_chat_model.py index 8c8f5ca4e..0932f3c97 100644 --- a/erniebot-agent/tests/integration_tests/chat_models/test_chat_model.py +++ b/erniebot-agent/tests/integration_tests/chat_models/test_chat_model.py @@ -1,4 +1,3 @@ -import os import unittest import pytest diff --git a/erniebot-agent/tests/unit_tests/testing_utils/mocks/mock_remote_file_client_server.py b/erniebot-agent/tests/unit_tests/testing_utils/mocks/mock_remote_file_client_server.py index 9731127c7..c07bfe8fb 100644 --- a/erniebot-agent/tests/unit_tests/testing_utils/mocks/mock_remote_file_client_server.py +++ b/erniebot-agent/tests/unit_tests/testing_utils/mocks/mock_remote_file_client_server.py @@ -1,80 +1,14 @@ import contextlib -import re -import uuid -import erniebot_agent.file_io.protocol as protocol +from erniebot_agent.file_io import protocol from erniebot_agent.file_io.remote_file import RemoteFile, RemoteFileClient -class FakeRemoteFileProtocol(object): - _UUID_PATTERN = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" - _FILE_ID_PREFIX = r"file-fake-remote-" - _FILE_ID_PATTERN = _FILE_ID_PREFIX + _UUID_PATTERN - - _followed = False - - @classmethod - def is_remote_file_id(cls, str_): - return re.fullmatch(cls._FILE_ID_PATTERN, str_) is not None - - @classmethod - def extract_remote_file_ids(cls, str_): - return re.findall(cls._FILE_ID_PATTERN, str_) - - @classmethod - def generate_remote_file_id(cls): - return cls._FILE_ID_PREFIX + str(uuid.uuid4()) - - get_timestamp = protocol.get_timestamp - - @classmethod - @contextlib.contextmanager - def follow(cls, old_protocol=None): - if not cls._followed: - names_methods_to_monkey_patch = ( - "is_remote_file_id", - "extract_remote_file_ids", - "get_timestamp", - ) - - if old_protocol is None: - old_protocol = protocol - - _old_methods = {} - - for method_name in names_methods_to_monkey_patch: - old_method = getattr(old_protocol, method_name) - new_method = getattr(cls, method_name) - _old_methods[method_name] = old_method - setattr(old_protocol, method_name, new_method) - - cls._followed = True - - yield - - for method_name in names_methods_to_monkey_patch: - old_method = _old_methods[method_name] - setattr(old_protocol, method_name, old_method) - - cls._followed = False - - else: - yield - - class FakeRemoteFileClient(RemoteFileClient): - _protocol = FakeRemoteFileProtocol - def __init__(self, server): super().__init__() - if server.protocol is not self.protocol: - raise ValueError("Server and client do not share the same protocol.") self._server = server - @property - def protocol(self): - return self._protocol - @property def server(self): if not self._server.started: @@ -107,28 +41,22 @@ async def create_temporary_url(self, file_id, expire_after): raise RuntimeError("Method not supported") def _create_file_obj_from_dict(self, dict_): - with self._protocol.follow(): - return RemoteFile( - id=dict_["id"], - filename=dict_["filename"], - byte_size=dict_["byte_size"], - created_at=dict_["created_at"], - purpose=dict_["purpose"], - metadata=dict_["metadata"], - client=self, - ) + return RemoteFile( + id=dict_["id"], + filename=dict_["filename"], + byte_size=dict_["byte_size"], + created_at=dict_["created_at"], + purpose=dict_["purpose"], + metadata=dict_["metadata"], + client=self, + ) class FakeRemoteFileServer(object): - _protocol = FakeRemoteFileProtocol - def __init__(self): super().__init__() self._storage = None - - @property - def protocol(self): - return self._protocol + self._file_id_iter = protocol.generate_fake_remote_file_ids() @property def storage(self): @@ -139,10 +67,10 @@ def started(self): return self._storage is not None async def upload_file(self, file_path, file_purpose, file_metadata): - id_ = self._protocol.generate_remote_file_id() + id_ = next(self._file_id_iter) filename = file_path.name byte_size = file_path.stat().st_size - created_at = self._protocol.get_timestamp() + created_at = protocol.get_timestamp() with file_path.open("rb") as f: contents = f.read() file = dict(