Skip to content

Commit

Permalink
Fix data race
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobholamovic committed Dec 22, 2023
1 parent 7df8462 commit dc5024a
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 47 deletions.
4 changes: 2 additions & 2 deletions erniebot-agent/examples/cv_agent/CV_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from erniebot_agent.agents.functional_agent import FunctionalAgent
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.file_io import (
from erniebot_agent.file import (
configure_global_file_manager,
get_global_file_manager,
)
Expand All @@ -25,7 +25,7 @@ def __init__(self):


async def run_agent():
file_manager = get_global_file_manager()
file_manager = await get_global_file_manager()
seg_file = await file_manager.create_file_from_path(file_path="cityscapes_demo.png", file_type="local")
clas_file = await file_manager.create_file_from_path(file_path="class_img.jpg", file_type="local")
ocr_file = await file_manager.create_file_from_path(file_path="ch.png", file_type="local")
Expand Down
10 changes: 5 additions & 5 deletions erniebot-agent/examples/plugins/multiple_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from erniebot_agent.agents.callback.default import get_no_ellipsis_callback
from erniebot_agent.agents.functional_agent import FunctionalAgent
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.file_io import get_global_file_manager
from erniebot_agent.file import get_global_file_manager
from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory
from erniebot_agent.messages import AIMessage, HumanMessage, Message
from erniebot_agent.memory import AIMessage, HumanMessage, Message
from erniebot_agent.tools.base import Tool
from erniebot_agent.tools.calculator_tool import CalculatorTool
from erniebot_agent.tools.schema import ToolParameterView
Expand All @@ -32,7 +32,7 @@ async def __call__(self, input_file_id: str, repeat_times: int) -> Dict[str, Any
if "<split>" in input_file_id:
input_file_id = input_file_id.split("<split>")[0]

file_manager = get_global_file_manager()
file_manager = await get_global_file_manager()
input_file = file_manager.look_up_file_by_id(input_file_id)
if input_file is None:
raise RuntimeError("File not found")
Expand Down Expand Up @@ -109,20 +109,20 @@ def examples(self) -> List[Message]:
# TODO(shiyutang): replace this when model is online
llm = ERNIEBot(model="ernie-3.5", api_type="custom")
memory = SlidingWindowMemory(max_round=1)
file_manager = get_global_file_manager()
# plugins = ["ChatFile", "eChart"]
plugins: List[str] = []
agent = FunctionalAgent(
llm=llm,
tools=[TextRepeaterTool(), TextRepeaterNoFileTool(), CalculatorTool()],
memory=memory,
file_manager=file_manager,
callbacks=get_no_ellipsis_callback(),
plugins=plugins,
)


async def run_agent():
file_manager = await get_global_file_manager()

docx_file = await file_manager.create_file_from_path(
file_path="浅谈牛奶的营养与消费趋势.docx",
file_type="remote",
Expand Down
5 changes: 1 addition & 4 deletions erniebot-agent/examples/rpg_game_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
from erniebot_agent.agents.base import Agent
from erniebot_agent.agents.schema import AgentFile, AgentResponse
from erniebot_agent.chat_models.erniebot import ERNIEBot
from erniebot_agent.file_io import get_global_file_manager
from erniebot_agent.file_io.base import File
from erniebot_agent.file_io.file_manager import FileManager
from erniebot_agent.file.base import File
from erniebot_agent.memory.sliding_window_memory import SlidingWindowMemory
from erniebot_agent.messages import AIMessage, HumanMessage, SystemMessage
from erniebot_agent.tools.base import BaseTool
Expand Down Expand Up @@ -87,7 +85,6 @@ def __init__(
tools=tools,
system_message=system_message,
)
self.file_manager: FileManager = get_global_file_manager()

async def handle_tool(self, tool_name: str, tool_args: str) -> str:
tool_response = await self._async_run_tool(
Expand Down
23 changes: 12 additions & 11 deletions erniebot-agent/src/erniebot_agent/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,16 @@ def __init__(
self._callback_manager = callbacks
else:
self._callback_manager = CallbackManager(callbacks)
if file_manager is None:
file_manager = get_global_file_manager()
self.plugins = plugins
self._file_manager = file_manager
self._plugins = plugins
self._init_file_repr()

def _init_file_repr(self):
self.file_needs_url = False

if self.plugins:
if self._plugins:
PLUGIN_WO_FILE = ["eChart"]
for plugin in self.plugins:
for plugin in self._plugins:
if plugin not in PLUGIN_WO_FILE:
self.file_needs_url = True

Expand Down Expand Up @@ -175,12 +173,8 @@ async def _sniff_and_extract_files_from_args(
for val in args.values():
if isinstance(val, str):
if protocol.is_file_id(val):
if self._file_manager is None:
logger.warning(
f"A file is used by {repr(tool)}, but the agent has no file manager to fetch it."
)
continue
file = self._file_manager.look_up_file_by_id(val)
file_manager = await self._get_file_manager()
file = file_manager.look_up_file_by_id(val)
if file is None:
raise RuntimeError(f"Unregistered file with ID {repr(val)} is used by {repr(tool)}.")
agent_files.append(AgentFile(file=file, type=file_type, used_by=tool.tool_name))
Expand All @@ -190,3 +184,10 @@ async def _sniff_and_extract_files_from_args(
for item in val:
agent_files.extend(await self._sniff_and_extract_files_from_args(item, tool, file_type))
return agent_files

async def _get_file_manager(self) -> FileManager:
if self._file_manager is None:
file_manager = await get_global_file_manager()
else:
file_manager = self._file_manager
return file_manager
18 changes: 11 additions & 7 deletions erniebot-agent/src/erniebot_agent/file/global_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import Any, Optional

import asyncio_atexit # type: ignore
Expand All @@ -21,24 +22,27 @@
from erniebot_agent.utils import config_from_environ as C

_global_file_manager: Optional[FileManager] = None
_lock = asyncio.Lock()


async def get_global_file_manager() -> FileManager:
global _global_file_manager
if _global_file_manager is None:
_global_file_manager = await _create_default_file_manager(access_token=None, save_dir=None)
async with _lock:
if _global_file_manager is None:
_global_file_manager = await _create_default_file_manager(access_token=None, save_dir=None)
return _global_file_manager


async def configure_global_file_manager(
access_token: Optional[str] = None, save_dir: Optional[str] = None, **opts: Any
) -> None:
global _global_file_manager
if _global_file_manager is not None:
raise RuntimeError(
"The global file manager can only be configured once before calling `get_global_file_manager`."
)
_global_file_manager = await _create_default_file_manager(access_token=access_token, save_dir=save_dir, **opts)
async with _lock:
if _global_file_manager is not None:
raise RuntimeError(
"The global file manager can only be configured once before calling `get_global_file_manager`."
)
_global_file_manager = await _create_default_file_manager(access_token=access_token, save_dir=save_dir, **opts)


async def _create_default_file_manager(
Expand Down
20 changes: 16 additions & 4 deletions erniebot-agent/src/erniebot_agent/tools/remote_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type
from erniebot_agent.file import get_global_file_manager

import requests

Expand Down Expand Up @@ -45,7 +46,7 @@ def __init__(
server_url: str,
headers: dict,
version: str,
file_manager: FileManager,
file_manager: Optional[FileManager],
examples: Optional[List[Message]] = None,
tool_name_prefix: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -86,11 +87,13 @@ async def fileid_to_byte(file_id, file_manager):

async def convert_to_file_data(file_data: str, format: str):
value = file_data.replace("<file>", "").replace("</file>", "")
byte_value = await fileid_to_byte(value, self.file_manager)
byte_value = await fileid_to_byte(value, file_manager)
if format == "byte":
byte_value = base64.b64encode(byte_value).decode()
return byte_value

file_manager = await self._get_file_manager()

# 1. replace fileid with byte string
parameter_file_info = get_file_info_from_param_view(self.tool_view.parameters)
for key in tool_arguments.keys():
Expand Down Expand Up @@ -203,19 +206,21 @@ async def send_request(self, tool_arguments: Dict[str, Any]) -> dict:
if len(returns_file_infos) == 0 and is_json_response(response):
return response.json()

file_manager = await self._get_file_manager()

file_metadata = {"tool_name": self.tool_name}
if is_json_response(response) and len(returns_file_infos) > 0:
response_json = response.json()
file_info = await parse_file_from_json_response(
response_json,
file_manager=self.file_manager,
file_manager=file_manager,
param_view=self.tool_view.returns, # type: ignore
tool_name=self.tool_name,
)
response_json.update(file_info)
return response_json
file = await parse_file_from_response(
response, self.file_manager, file_infos=returns_file_infos, file_metadata=file_metadata
response, file_manager, file_infos=returns_file_infos, file_metadata=file_metadata
)

if file is not None:
Expand Down Expand Up @@ -293,6 +298,13 @@ def __adhoc_post_process__(self, tool_response: dict) -> dict:
result.pop(key)
return tool_response

async def _get_file_manager(self) -> FileManager:
if self.file_manager is None:
file_manager = await get_global_file_manager()
else:
file_manager = self.file_manager
return file_manager


class RemoteToolRegistor:
def __init__(self) -> None:
Expand Down
4 changes: 0 additions & 4 deletions erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from openapi_spec_validator.readers import read_from_filename
from yaml import safe_dump

from erniebot_agent.file import get_global_file_manager
from erniebot_agent.file.file_manager import FileManager
from erniebot_agent.memory.messages import (
AIMessage,
Expand Down Expand Up @@ -201,9 +200,6 @@ def from_openapi_dict(
)
)

if file_manager is None:
file_manager = get_global_file_manager()

return RemoteToolkit(
openapi=openapi_dict["openapi"],
info=info,
Expand Down
20 changes: 10 additions & 10 deletions erniebot-agent/tests/integration_tests/apihub/test_pp_shituv2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations
from erniebot_agent.file.file_manager import FileManager

import pytest

from erniebot_agent.file_io import get_global_file_manager
from erniebot_agent.tools import RemoteToolkit

from .base import RemoteToolTesting
Expand All @@ -17,13 +17,13 @@ async def test_pp_shituv2(self):

agent = self.get_agent(toolkit)

file_manager = get_global_file_manager()
file_path = self.download_file(
"https://paddlenlp.bj.bcebos.com/ebagent/ci/fixtures/remote-tools/pp_shituv2_input_img.png"
)
file = await file_manager.create_file_from_path(file_path)
async with FileManager() as file_manager:
file_path = self.download_file(
"https://paddlenlp.bj.bcebos.com/ebagent/ci/fixtures/remote-tools/pp_shituv2_input_img.png"
)
file = await file_manager.create_file_from_path(file_path)

result = await agent.async_run("对这张图片进行通用识别,包含的文件为:", files=[file])
self.assertEqual(len(result.files), 2)
self.assertEqual(len(result.actions), 1)
self.assertIn("file-", result.text)
result = await agent.async_run("对这张图片进行通用识别,包含的文件为:", files=[file])
self.assertEqual(len(result.files), 2)
self.assertEqual(len(result.actions), 1)
self.assertIn("file-", result.text)

0 comments on commit dc5024a

Please sign in to comment.