Skip to content

Commit

Permalink
Fix protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobholamovic committed Dec 21, 2023
1 parent bc8bbd8 commit b280333
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 88 deletions.
11 changes: 9 additions & 2 deletions erniebot-agent/src/erniebot_agent/file_io/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import datetime
import re
from typing import List, Literal
from typing import Generator, List, Literal

from typing_extensions import TypeAlias

Expand All @@ -23,7 +23,8 @@
_LOCAL_FILE_ID_PREFIX = "file-local-"
_UUID_PATTERN = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
_LOCAL_FILE_ID_PATTERN = _LOCAL_FILE_ID_PREFIX + _UUID_PATTERN
_REMOTE_FILE_ID_PATTERN = r"file-[0-9]{15}"
_REMOTE_FILE_ID_PREFIX = "file-"
_REMOTE_FILE_ID_PATTERN = _REMOTE_FILE_ID_PREFIX + r"[0-9]{15}"

_compiled_local_file_id_pattern = re.compile(_LOCAL_FILE_ID_PATTERN)
_compiled_remote_file_id_pattern = re.compile(_REMOTE_FILE_ID_PATTERN)
Expand Down Expand Up @@ -59,3 +60,9 @@ def extract_local_file_ids(str_: str) -> List[str]:

def extract_remote_file_ids(str_: str) -> List[str]:
return _compiled_remote_file_id_pattern.findall(str_)


def generate_fake_remote_file_ids() -> Generator[str, None, None]:
counter = 0
while True:
yield _REMOTE_FILE_ID_PREFIX + f"{counter:015d}"
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import unittest

import pytest
Expand Down
Original file line number Diff line number Diff line change
@@ -1,80 +1,14 @@
import contextlib
import re
import uuid

import erniebot_agent.file_io.protocol as protocol
from erniebot_agent.file_io import protocol
from erniebot_agent.file_io.remote_file import RemoteFile, RemoteFileClient


class FakeRemoteFileProtocol(object):
_UUID_PATTERN = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
_FILE_ID_PREFIX = r"file-fake-remote-"
_FILE_ID_PATTERN = _FILE_ID_PREFIX + _UUID_PATTERN

_followed = False

@classmethod
def is_remote_file_id(cls, str_):
return re.fullmatch(cls._FILE_ID_PATTERN, str_) is not None

@classmethod
def extract_remote_file_ids(cls, str_):
return re.findall(cls._FILE_ID_PATTERN, str_)

@classmethod
def generate_remote_file_id(cls):
return cls._FILE_ID_PREFIX + str(uuid.uuid4())

get_timestamp = protocol.get_timestamp

@classmethod
@contextlib.contextmanager
def follow(cls, old_protocol=None):
if not cls._followed:
names_methods_to_monkey_patch = (
"is_remote_file_id",
"extract_remote_file_ids",
"get_timestamp",
)

if old_protocol is None:
old_protocol = protocol

_old_methods = {}

for method_name in names_methods_to_monkey_patch:
old_method = getattr(old_protocol, method_name)
new_method = getattr(cls, method_name)
_old_methods[method_name] = old_method
setattr(old_protocol, method_name, new_method)

cls._followed = True

yield

for method_name in names_methods_to_monkey_patch:
old_method = _old_methods[method_name]
setattr(old_protocol, method_name, old_method)

cls._followed = False

else:
yield


class FakeRemoteFileClient(RemoteFileClient):
_protocol = FakeRemoteFileProtocol

def __init__(self, server):
super().__init__()
if server.protocol is not self.protocol:
raise ValueError("Server and client do not share the same protocol.")
self._server = server

@property
def protocol(self):
return self._protocol

@property
def server(self):
if not self._server.started:
Expand Down Expand Up @@ -107,28 +41,22 @@ async def create_temporary_url(self, file_id, expire_after):
raise RuntimeError("Method not supported")

def _create_file_obj_from_dict(self, dict_):
with self._protocol.follow():
return RemoteFile(
id=dict_["id"],
filename=dict_["filename"],
byte_size=dict_["byte_size"],
created_at=dict_["created_at"],
purpose=dict_["purpose"],
metadata=dict_["metadata"],
client=self,
)
return RemoteFile(
id=dict_["id"],
filename=dict_["filename"],
byte_size=dict_["byte_size"],
created_at=dict_["created_at"],
purpose=dict_["purpose"],
metadata=dict_["metadata"],
client=self,
)


class FakeRemoteFileServer(object):
_protocol = FakeRemoteFileProtocol

def __init__(self):
super().__init__()
self._storage = None

@property
def protocol(self):
return self._protocol
self._file_id_iter = protocol.generate_fake_remote_file_ids()

@property
def storage(self):
Expand All @@ -139,10 +67,10 @@ def started(self):
return self._storage is not None

async def upload_file(self, file_path, file_purpose, file_metadata):
id_ = self._protocol.generate_remote_file_id()
id_ = next(self._file_id_iter)
filename = file_path.name
byte_size = file_path.stat().st_size
created_at = self._protocol.get_timestamp()
created_at = protocol.get_timestamp()
with file_path.open("rb") as f:
contents = f.read()
file = dict(
Expand Down

0 comments on commit b280333

Please sign in to comment.