Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Custom Temporal Model Provider with streaming support for OpenAI agents."""
from __future__ import annotations

import json
import uuid
from typing import Any, List, Union, Optional, override

Expand Down Expand Up @@ -60,9 +61,9 @@
from agentex.lib import adk
from agentex.lib.utils.logging import make_logger
from agentex.lib.core.tracing.tracer import AsyncTracer
from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta
from agentex.types.task_message_delta import TextDelta, ToolRequestDelta, ReasoningContentDelta, ReasoningSummaryDelta
from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
from agentex.types.task_message_content import TextContent, ReasoningContent
from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
streaming_task_id,
Expand Down Expand Up @@ -678,12 +679,27 @@ async def get_response(
streaming_mode=self.streaming_mode,
).__aenter__()
elif item and getattr(item, 'type', None) == 'function_call':
# Track the function call being streamed
# Open a streaming context per function call so argument
# deltas can be published incrementally. Coalescing and
# mode dispatch are handled by the streaming layer.
call_id = getattr(item, 'call_id', '')
tool_name = getattr(item, 'name', '')
call_context = await adk.streaming.streaming_task_message_context(
task_id=task_id,
initial_content=ToolRequestContent(
author="agent",
tool_call_id=call_id,
name=tool_name,
arguments={},
),
streaming_mode=self.streaming_mode,
).__aenter__()
function_calls_in_progress[output_index] = {
'id': getattr(item, 'id', ''),
'call_id': getattr(item, 'call_id', ''),
'name': getattr(item, 'name', ''),
'call_id': call_id,
'name': tool_name,
'arguments': getattr(item, 'arguments', ''),
'context': call_context,
}
logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}")

Expand All @@ -704,8 +720,24 @@ async def get_response(
output_index = getattr(event, 'output_index', 0)
delta = getattr(event, 'delta', '')

if output_index in function_calls_in_progress:
function_calls_in_progress[output_index]['arguments'] += delta
call_data = function_calls_in_progress.get(output_index)
if call_data is not None:
call_data['arguments'] += delta
call_context = call_data.get('context')
if call_context is not None:
try:
await call_context.stream_update(StreamTaskMessageDelta(
parent_task_message=call_context.task_message,
delta=ToolRequestDelta(
tool_call_id=call_data['call_id'],
name=call_data['name'],
arguments_delta=delta,
type="tool_request",
),
type="delta",
))
except Exception as e:
logger.warning(f"Failed to send tool request delta: {e}")
logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...")

elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
Expand Down Expand Up @@ -830,6 +862,40 @@ async def get_response(
)
output_items.append(tool_call)

# Emit the final ToolRequestContent and close the
# per-call streaming context. If the model produced
# invalid JSON args (truncation, hallucination), fall
# back to an empty dict so the streaming layer can
# still persist a message.
call_context = call_data.get('context')
if call_context is not None:
raw_args = call_data['arguments'] or ''
try:
parsed_args = json.loads(raw_args) if raw_args else {}
except json.JSONDecodeError:
logger.warning(
f"Failed to parse tool call arguments for {call_data['name']} "
f"(raw_args_bytes={len(raw_args)})"
)
parsed_args = {}
Comment thread
greptile-apps[bot] marked this conversation as resolved.
try:
await call_context.stream_update(StreamTaskMessageFull(
parent_task_message=call_context.task_message,
content=ToolRequestContent(
author="agent",
tool_call_id=call_data['call_id'],
name=call_data['name'],
arguments=parsed_args,
),
type="full",
))
except Exception as e:
logger.warning(f"Failed to send tool request full update: {e}")
try:
await call_context.close()
except Exception as e:
logger.warning(f"Failed to close tool request context: {e}")

elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
Comment on lines +894 to 899
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Double-close of streaming context on every successful function call — after ResponseOutputItemDoneEvent closes the context (line 894), call_data['context'] is left non-None. The orphan-cleanup loop (lines 933–940) then iterates function_calls_in_progress.values() and closes every entry whose 'context' is not None — which includes every context that was already closed normally. close() will be called twice for every function call that finished cleanly, potentially publishing a duplicate "stream ended" event to Redis for each tool call in the response.

Suggested change
try:
await call_context.close()
except Exception as e:
logger.warning(f"Failed to close tool request context: {e}")
elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
try:
await call_context.close()
except Exception as e:
logger.warning(f"Failed to close tool request context: {e}")
finally:
call_data['context'] = None
elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
Prompt To Fix With AI
This is a comment left during a code review.
Path: src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py
Line: 893-898

Comment:
**Double-close of streaming context on every successful function call** — after `ResponseOutputItemDoneEvent` closes the context (line 894), `call_data['context']` is left non-`None`. The orphan-cleanup loop (lines 933–940) then iterates `function_calls_in_progress.values()` and closes every entry whose `'context'` is not `None` — which includes every context that was already closed normally. `close()` will be called twice for every function call that finished cleanly, potentially publishing a duplicate "stream ended" event to Redis for each tool call in the response.

```suggestion
                                    try:
                                        await call_context.close()
                                    except Exception as e:
                                        logger.warning(f"Failed to close tool request context: {e}")
                                    finally:
                                        call_data['context'] = None

                    elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
```

How can I resolve this? If you propose a fix, please make it concise.

Fix in Cursor Fix in Claude Code Fix in Codex

# New reasoning part/summary started - reset accumulator
part = getattr(event, 'part', None)
Expand Down Expand Up @@ -863,6 +929,17 @@ async def get_response(
await streaming_context.close()
streaming_context = None

# Defensive: close any function call contexts that didn't see a
# ResponseOutputItemDoneEvent (truncated stream, error mid-call).
for call_data in function_calls_in_progress.values():
call_context = call_data.get('context')
if call_context is not None:
try:
await call_context.close()
except Exception as e:
logger.warning(f"Failed to close orphaned tool request context: {e}")
call_data['context'] = None

# Build the response from output items collected during streaming
# Create output from the items we collected
response_output = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
from openai.types.responses import (
ResponseCompletedEvent,
ResponseTextDeltaEvent,
ResponseOutputItemDoneEvent,
ResponseOutputItemAddedEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseReasoningSummaryTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent,
)


Expand Down Expand Up @@ -851,6 +854,197 @@ async def test_missing_task_id_error(self, streaming_model):
)


class TestStreamingModelFunctionCallArgsStreaming:
    """Ensure function-call argument deltas are surfaced as ``ToolRequestDelta``
    updates and that a terminal ``ToolRequestContent`` Full is emitted on
    ``ResponseOutputItemDoneEvent``.

    If this regresses, write-heavy tools (``write_file``, ``apply_patch``)
    buffer their entire argument body inside ``invoke_model_activity`` and the
    UI stalls for several seconds while the model is still producing tokens.
    """

    @staticmethod
    def _build_function_call_stream(arguments_text: str):
        """Assemble the event sequence for one streamed function_call.

        Reproduces the production ordering: Added → N × ArgumentsDelta →
        ArgumentsDone → OutputItemDone → ResponseCompleted. ``spec=...`` keeps
        production's ``isinstance`` dispatch working without triggering
        pydantic validation.
        """
        fn_item = MagicMock()
        fn_item.type = "function_call"
        fn_item.id = "fc_abc"
        fn_item.call_id = "call_abc"
        fn_item.name = "write_file"
        fn_item.arguments = ""

        added = MagicMock(spec=ResponseOutputItemAddedEvent)
        added.item = fn_item
        added.output_index = 0

        # Carve the argument text into a handful of pieces so the per-delta
        # branch in production is exercised more than once.
        step = max(1, len(arguments_text) // 3) if arguments_text else 1
        pieces = [
            arguments_text[start:start + step]
            for start in range(0, len(arguments_text), step)
        ] or [""]

        delta_events = []
        for piece in pieces:
            delta_ev = MagicMock(spec=ResponseFunctionCallArgumentsDeltaEvent)
            delta_ev.delta = piece
            delta_ev.output_index = 0
            delta_events.append(delta_ev)

        args_finished = MagicMock(spec=ResponseFunctionCallArgumentsDoneEvent)
        args_finished.arguments = arguments_text
        args_finished.output_index = 0

        item_finished = MagicMock(spec=ResponseOutputItemDoneEvent)
        item_finished.item = fn_item
        item_finished.output_index = 0

        response_done = MagicMock(spec=ResponseCompletedEvent)
        response_done.response = MagicMock(output=[], usage=MagicMock(), id=None)

        return [added, *delta_events, args_finished, item_finished, response_done], pieces

    @staticmethod
    def _install_real_task_message(mock_adk_streaming, task_id: str):
        """Swap the autouse fixture's MagicMock ``task_message`` for a genuine
        ``TaskMessage`` so production's
        ``StreamTaskMessageDelta(parent_task_message=...)`` construction passes
        pydantic validation. The stock mock suffices for tests that only assert
        on the context's ``__aenter__`` call, but not for tests that drive
        ``stream_update`` end-to-end.
        """
        from agentex.types.task_message import TaskMessage
        from agentex.types.task_message_content import ToolRequestContent

        context = mock_adk_streaming.streaming_task_message_context.return_value
        context.task_message = TaskMessage(
            id="msg_test",
            task_id=task_id,
            content=ToolRequestContent(
                author="agent",
                tool_call_id="call_abc",
                name="write_file",
                arguments={},
            ),
            streaming_status="IN_PROGRESS",
        )
        return context

    @pytest.mark.asyncio
    async def test_function_call_emits_argument_deltas_and_final_full(
        self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id
    ):
        """A function_call with well-formed JSON args should produce:

        (1) exactly one streaming context opened with ``ToolRequestContent``
            initial_content,
        (2) one ``StreamTaskMessageDelta`` per ``ArgumentsDelta`` event whose
            ``ToolRequestDelta`` carries the right ``tool_call_id`` and
            ``arguments_delta``,
        (3) a single final ``StreamTaskMessageFull`` whose
            ``ToolRequestContent.arguments`` is the parsed JSON dict.
        """
        from agentex.types.task_message_delta import ToolRequestDelta
        from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
        from agentex.types.task_message_content import ToolRequestContent

        context = self._install_real_task_message(mock_adk_streaming, sample_task_id)

        args_text = '{"path": "/tmp/foo.txt", "contents": "hello world"}'
        events, chunks = self._build_function_call_stream(args_text)

        fake_stream = AsyncMock()
        fake_stream.__aiter__.return_value = iter(events)
        streaming_model.client.responses.create = AsyncMock(return_value=fake_stream)

        await streaming_model.get_response(
            system_instructions=None,
            input="please write foo",
            model_settings=ModelSettings(),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=None,
        )

        # 1. Exactly one streaming context opened with ToolRequestContent.
        opens = []
        for open_call in mock_adk_streaming.streaming_task_message_context.call_args_list:
            if isinstance(open_call.kwargs.get("initial_content"), ToolRequestContent):
                opens.append(open_call)
        assert len(opens) == 1, f"expected one ToolRequest context, got {len(opens)}"
        initial = opens[0].kwargs["initial_content"]
        assert initial.tool_call_id == "call_abc"
        assert initial.name == "write_file"

        # 2. Each ArgumentsDelta event yielded one
        #    StreamTaskMessageDelta(ToolRequestDelta), preserving the delta
        #    text exactly.
        delta_updates = []
        for update_call in context.stream_update.call_args_list:
            if not update_call.args:
                continue
            candidate = update_call.args[0]
            if isinstance(candidate, StreamTaskMessageDelta) and isinstance(candidate.delta, ToolRequestDelta):
                delta_updates.append(candidate)
        assert len(delta_updates) == len(chunks)
        for update, expected_chunk in zip(delta_updates, chunks):
            assert update.delta.tool_call_id == "call_abc"
            assert update.delta.name == "write_file"
            assert update.delta.arguments_delta == expected_chunk

        # 3. One final StreamTaskMessageFull(ToolRequestContent) carrying the
        #    parsed arguments.
        full_updates = []
        for update_call in context.stream_update.call_args_list:
            if not update_call.args:
                continue
            candidate = update_call.args[0]
            if isinstance(candidate, StreamTaskMessageFull) and isinstance(candidate.content, ToolRequestContent):
                full_updates.append(candidate)
        assert len(full_updates) == 1
        final = full_updates[0].content
        assert final.tool_call_id == "call_abc"
        assert final.name == "write_file"
        assert final.arguments == {"path": "/tmp/foo.txt", "contents": "hello world"}

    @pytest.mark.asyncio
    async def test_function_call_malformed_args_fall_back_to_empty_dict(
        self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id, caplog
    ):
        """Invalid JSON args must not break the stream: the final
        ``ToolRequestContent`` carries ``arguments={}`` and a warning is
        logged. The raw delta stream is preserved either way.
        """
        from agentex.types.task_message_update import StreamTaskMessageFull
        from agentex.types.task_message_content import ToolRequestContent

        context = self._install_real_task_message(mock_adk_streaming, sample_task_id)

        # Deliberately truncated JSON — missing closing brace.
        events, _ = self._build_function_call_stream('{"path": "/tmp/foo.txt", "contents":')

        fake_stream = AsyncMock()
        fake_stream.__aiter__.return_value = iter(events)
        streaming_model.client.responses.create = AsyncMock(return_value=fake_stream)

        with caplog.at_level("WARNING"):
            await streaming_model.get_response(
                system_instructions=None,
                input="please write foo",
                model_settings=ModelSettings(),
                tools=[],
                output_schema=None,
                handoffs=[],
                tracing=None,
            )

        full_updates = []
        for update_call in context.stream_update.call_args_list:
            if not update_call.args:
                continue
            candidate = update_call.args[0]
            if isinstance(candidate, StreamTaskMessageFull) and isinstance(candidate.content, ToolRequestContent):
                full_updates.append(candidate)
        assert len(full_updates) == 1
        assert full_updates[0].content.arguments == {}
        assert any("Failed to parse tool call arguments" in record.getMessage() for record in caplog.records)


class TestStreamingModelUsageResponseIdAndCacheKey:
"""Cover real-Usage capture, real response_id, span emission, and opt-in prompt_cache_key."""

Expand Down
Loading