Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
import json
import logging
import os
import tempfile
import threading
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping, Sequence
from typing import cast
from contextlib import suppress
from typing import Protocol, cast

from agent_framework import (
ChatOptions,
Expand Down Expand Up @@ -109,11 +112,105 @@
logger = logging.getLogger(__name__)


class ApprovalStorage(Protocol):
"""Storage for saving function approval requests."""

async def save_approval_request(self, approval_request_id: str, request: Content) -> None:
"""Save a function approval request under the given ID."""
...

async def load_approval_request(self, approval_request_id: str) -> Content:
"""Load a function approval request by its ID."""
...


class InMemoryFunctionApprovalStorage:
"""An in-memory storage for function approval requests."""

def __init__(self) -> None:
self._store: dict[str, Content] = {}

async def save_approval_request(self, approval_request_id: str, request: Content) -> None:
if approval_request_id in self._store:
raise ValueError(f"Approval request with ID '{approval_request_id}' already exists.")
self._store[approval_request_id] = request

async def load_approval_request(self, approval_request_id: str) -> Content:
if approval_request_id not in self._store:
raise KeyError(f"Approval request with ID '{approval_request_id}' does not exist.")
return self._store[approval_request_id]


class FileBasedFunctionApprovalStorage:
"""A simple file-based storage for function approval requests.

Concurrent writes from multiple threads in the same process are
serialized by a ``threading.Lock``, and the on-disk JSON file is
updated atomically (write to a temp file, then ``os.replace``) so a
crash mid-write cannot leave a partially written file behind.
"""

def __init__(self, storage_path: str) -> None:
self._storage_path = storage_path
self._lock = threading.Lock()

def _create_storage_file_if_not_exists_sync(self) -> None:
"""Lazy-create the storage file (and its parent directory) if it does not already exist.

Uses exclusive-create mode (``"x"``) so a concurrent creator cannot
be truncated by an ``open(..., "w")`` after a stale existence check.
"""
os.makedirs(os.path.dirname(self._storage_path) or ".", exist_ok=True)
with suppress(FileExistsError), open(self._storage_path, "x") as f:
json.dump({}, f)

def _atomic_write(self, data: dict[str, Any]) -> None:
"""Atomically replace the storage file with the serialized ``data``."""
directory = os.path.dirname(self._storage_path) or "."
# Serialize first so any error doesn't leave a partial file behind.
serialized = json.dumps(data)
fd, tmp_path = tempfile.mkstemp(prefix=".approvals-", suffix=".tmp", dir=directory)
try:
with os.fdopen(fd, "w") as tmp:
tmp.write(serialized)
os.replace(tmp_path, self._storage_path)
except BaseException:
with suppress(OSError):
os.unlink(tmp_path)
raise

def _save_sync(self, approval_request_id: str, request: Content) -> None:
with self._lock:
self._create_storage_file_if_not_exists_sync()
with open(self._storage_path) as f:
data = json.load(f)
if approval_request_id in data:
raise ValueError(f"Approval request with ID '{approval_request_id}' already exists.")
data[approval_request_id] = request.to_dict()
self._atomic_write(data)

def _load_sync(self, approval_request_id: str) -> Content:
with self._lock:
self._create_storage_file_if_not_exists_sync()
with open(self._storage_path) as f:
data = json.load(f)
if approval_request_id not in data:
raise KeyError(f"Approval request with ID '{approval_request_id}' does not exist.")
return Content.from_dict(data[approval_request_id])

async def save_approval_request(self, approval_request_id: str, request: Content) -> None:
await asyncio.to_thread(self._save_sync, approval_request_id, request)

async def load_approval_request(self, approval_request_id: str) -> Content:
return await asyncio.to_thread(self._load_sync, approval_request_id)


class ResponsesHostServer(ResponsesAgentServerHost):
"""A responses server host for an agent."""

# TODO(@taochen): Allow a different checkpoint storage that stores checkpoints externally
CHECKPOINT_STORAGE_PATH = "/.checkpoints"
FUNCTION_APPROVAL_STORAGE_PATH = "/.function_approvals/approval_requests.json"

def __init__(
self,
Expand Down Expand Up @@ -171,6 +268,11 @@ def __init__(
self._is_workflow_agent = True

self._agent = agent
self._approval_storage = (
FileBasedFunctionApprovalStorage(self.FUNCTION_APPROVAL_STORAGE_PATH)
if self.config.is_hosted
else InMemoryFunctionApprovalStorage()
)
self.response_handler(self._handle_response) # pyright: ignore[reportUnknownMemberType]

async def _handle_response(
Expand All @@ -192,10 +294,15 @@ async def _handle_inner_agent(
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
"""Handle the creation of a response for a regular (non-workflow) agent."""
input_items = await context.get_input_items()
input_messages = _items_to_messages(input_items)
input_messages = await _items_to_messages(input_items, approval_storage=self._approval_storage)

history = await context.get_history()
run_kwargs: dict[str, Any] = {"messages": [*_output_items_to_messages(history), *input_messages]}
run_kwargs: dict[str, Any] = {
"messages": [
*(await _output_items_to_messages(history, approval_storage=self._approval_storage)),
*input_messages,
]
}
is_streaming_request = request.stream is not None and request.stream is True

chat_options, are_options_set = _to_chat_options(request)
Expand All @@ -216,7 +323,11 @@ async def _handle_inner_agent(

for message in response.messages:
for content in message.contents:
async for item in _to_outputs(response_event_stream, content):
async for item in _to_outputs(
response_event_stream,
content,
approval_storage=self._approval_storage,
):
yield item

yield response_event_stream.emit_completed()
Expand All @@ -232,7 +343,11 @@ async def _handle_inner_agent(
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(response_event_stream, content):
async for item in _to_outputs(
response_event_stream,
content,
approval_storage=self._approval_storage,
):
yield item
tracker.needs_async = False

Expand All @@ -254,7 +369,7 @@ async def _handle_inner_workflow(
by the hosting infrastructure or files will be preserved upon deactivation.
"""
input_items = await context.get_input_items()
input_messages = _items_to_messages(input_items)
input_messages = await _items_to_messages(input_items)
Comment thread
TaoChenOSU marked this conversation as resolved.
is_streaming_request = request.stream is not None and request.stream is True

_, are_options_set = _to_chat_options(request)
Expand Down Expand Up @@ -581,26 +696,32 @@ def _to_chat_options(request: CreateResponse) -> tuple[ChatOptions, bool]:
# region Input Message Conversion


def _items_to_messages(input_items: Sequence[Item]) -> list[Message]:
async def _items_to_messages(
input_items: Sequence[Item], *, approval_storage: ApprovalStorage | None = None
) -> list[Message]:
"""Converts a sequence of input items to a list of Messages, one per item.

Args:
input_items: The input items to convert.
approval_storage: An optional ApprovalStorage instance used to look up
approval requests when converting MCP approval response items.

Returns:
A list of Messages, one per supported input item.
"""
messages: list[Message] = []
for item in input_items:
messages.append(_item_to_message(item))
messages.append(await _item_to_message(item, approval_storage=approval_storage))
return messages


def _item_to_message(item: Item) -> Message:
async def _item_to_message(item: Item, *, approval_storage: ApprovalStorage | None = None) -> Message:
"""Converts an Item to a Message.

Args:
item: The Item to convert.
approval_storage: An optional ApprovalStorage instance used to look up
approval requests when converting MCP approval response items.

Returns:
The converted Message.
Expand Down Expand Up @@ -659,27 +780,26 @@ def _item_to_message(item: Item) -> Message:

if item.type == "mcp_approval_request":
mcp_req = cast(ItemMcpApprovalRequest, item)
mcp_call_content = Content.from_mcp_server_tool_call(
mcp_req.id,
mcp_req.name,
server_name=mcp_req.server_label,
arguments=mcp_req.arguments,
)
if approval_storage is not None:
function_approval_request_content = await approval_storage.load_approval_request(mcp_req.id)
else:
raise ValueError("ApprovalStorage is required to load approval request.")
return Message(
role="assistant",
contents=[Content.from_function_approval_request(mcp_req.id, mcp_call_content)],
contents=[function_approval_request_content],
)

if item.type == "mcp_approval_response":
mcp_resp = cast(MCPApprovalResponse, item)
placeholder_content = Content.from_function_call(mcp_resp.approval_request_id, "mcp_approval")
if approval_storage is not None:
function_approval_request_content = await approval_storage.load_approval_request(
mcp_resp.approval_request_id
)
else:
raise ValueError("ApprovalStorage is required to load approval request.")
return Message(
role="user",
contents=[
Content.from_function_approval_response(
mcp_resp.approve, mcp_resp.approval_request_id, placeholder_content
)
],
contents=[function_approval_request_content.to_function_approval_response(mcp_resp.approve)],
)

if item.type == "code_interpreter_call":
Expand Down Expand Up @@ -846,26 +966,34 @@ def _item_to_message(item: Item) -> Message:
raise ValueError(f"Unsupported Item type: {item.type}")


def _output_items_to_messages(history: Sequence[OutputItem]) -> list[Message]:
async def _output_items_to_messages(
history: Sequence[OutputItem],
*,
approval_storage: ApprovalStorage | None = None,
) -> list[Message]:
"""Converts a sequence of OutputItem objects to a list of Message objects.

Args:
history (Sequence[OutputItem]): The sequence of OutputItem objects to convert.
approval_storage (ApprovalStorage | None, optional): The approval storage to use for
resolving MCP approval requests. Defaults to None.

Returns:
list[Message]: The list of Message objects.
"""
messages: list[Message] = []
for item in history:
messages.append(_output_item_to_message(item))
messages.append(await _output_item_to_message(item, approval_storage=approval_storage))
return messages


def _output_item_to_message(item: OutputItem) -> Message:
async def _output_item_to_message(item: OutputItem, *, approval_storage: ApprovalStorage | None = None) -> Message:
"""Converts an OutputItem to a Message.

Args:
item (OutputItem): The OutputItem to convert.
approval_storage (ApprovalStorage | None, optional): The approval storage to use for
resolving MCP approval requests. Defaults to None.

Returns:
Message: The converted Message.
Expand Down Expand Up @@ -922,24 +1050,27 @@ def _output_item_to_message(item: OutputItem) -> Message:

if item.type == "mcp_approval_request":
mcp_req = cast(OutputItemMcpApprovalRequest, item)
mcp_call_content = Content.from_mcp_server_tool_call(
mcp_req.id,
mcp_req.name,
server_name=mcp_req.server_label,
arguments=mcp_req.arguments,
)
if approval_storage is not None:
function_approval_request_content = await approval_storage.load_approval_request(mcp_req.id)
else:
raise ValueError("ApprovalStorage is required to load approval request.")
return Message(
role="assistant",
Comment thread
TaoChenOSU marked this conversation as resolved.
contents=[Content.from_function_approval_request(mcp_req.id, mcp_call_content)],
contents=[function_approval_request_content],
)

if item.type == "mcp_approval_response":
mcp_resp = cast(OutputItemMcpApprovalResponseResource, item)
# Build a placeholder function_call Content since the original call details are not available
placeholder_content = Content.from_function_call(mcp_resp.approval_request_id, "mcp_approval")
if approval_storage is not None:
function_approval_request_content = await approval_storage.load_approval_request(
mcp_resp.approval_request_id
)
else:
raise ValueError("ApprovalStorage is required to load approval request.")

return Message(
role="user",
contents=[Content.from_function_approval_response(mcp_resp.approve, mcp_resp.id, placeholder_content)],
contents=[function_approval_request_content.to_function_approval_response(mcp_resp.approve)],
)

if item.type == "code_interpreter_call":
Expand Down Expand Up @@ -1237,12 +1368,18 @@ def _arguments_to_str(arguments: str | Mapping[str, Any] | None) -> str:
return json.dumps(arguments)


async def _to_outputs(stream: ResponseEventStream, content: Content) -> AsyncIterator[ResponseStreamEvent]:
async def _to_outputs(
stream: ResponseEventStream,
content: Content,
*,
approval_storage: ApprovalStorage | None = None,
) -> AsyncIterator[ResponseStreamEvent]:
"""Converts a Content object to an async sequence of ResponseStreamEvent objects.

Args:
stream: The ResponseEventStream to use for building events.
content: The Content to convert.
approval_storage: An optional ApprovalStorage instance to use for saving and loading function approval requests.

Yields:
ResponseStreamEvent: The converted event objects.
Expand Down Expand Up @@ -1320,6 +1457,31 @@ async def _to_outputs(stream: ResponseEventStream, content: Content) -> AsyncIte
max_output_length=content.max_output_length,
):
yield event
elif content.type == "function_approval_request":
function_call: Content = content.function_call # type: ignore
server_label = function_call.additional_properties.get("server_label", "agent_framework")
request_saved = False
async for event in stream.aoutput_item_mcp_approval_request(
server_label,
function_call.name, # type: ignore
_arguments_to_str(function_call.arguments),
):
if approval_storage is not None and not request_saved:
# Extract the approval request ID generated by the infrastructure
# when the approval request item is added to the stream. Save the
# approval request to the approval storage so it can be retrieved later
# for round trips where the original approval request needs to be looked up.
item = getattr(event, "item", None)
if item is not None and getattr(item, "id", None) is not None:
Comment thread
TaoChenOSU marked this conversation as resolved.
approval_request_id = cast(str, item.id) # type: ignore
await approval_storage.save_approval_request(approval_request_id, content)
request_saved = True
yield event
if approval_storage is not None and not request_saved:
logger.warning(
"Approval request was not saved to approval storage because the approval request ID "
"could not be extracted from the stream event."
)
else:
# Log a warning for unsupported content types instead of raising an error to avoid breaking the response stream.
logger.warning(f"Content type '{content.type}' is not supported yet. This is usually safe to ignore.")
Expand Down
Loading
Loading