Skip to content

Commit 8a0994b

Browse files
committed
feat(streaming): stream tool call argument deltas in TemporalStreamingModel
Wire ResponseFunctionCallArgumentsDeltaEvent into the streaming layer introduced in #333, so write-heavy tools (write_file, apply_patch) no longer freeze the UI for the duration of argument generation.

The model now opens a per-function-call streaming context with a ToolRequestContent placeholder, emits ToolRequestDelta updates for each argument delta, and finalizes with a StreamTaskMessageFull containing the parsed arguments on ResponseOutputItemDoneEvent.

Coalescing and mode dispatch are inherited from the existing streaming infrastructure — no new flags or surface area. ModelResponse output is unchanged; activity determinism is unaffected.

End-of-loop cleanup defensively closes any function-call contexts that didn't see a Done event (truncated stream or mid-stream exception).

Adds two tests covering the happy path (well-formed JSON args → deltas + parsed Full) and the malformed-args fallback (invalid JSON → empty dict + WARNING log).
1 parent 6092595 commit 8a0994b

2 files changed

Lines changed: 277 additions & 7 deletions

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Custom Temporal Model Provider with streaming support for OpenAI agents."""
22
from __future__ import annotations
33

4+
import json
45
import uuid
56
from typing import Any, List, Union, Optional, override
67

@@ -60,9 +61,9 @@
6061
from agentex.lib import adk
6162
from agentex.lib.utils.logging import make_logger
6263
from agentex.lib.core.tracing.tracer import AsyncTracer
63-
from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta
64+
from agentex.types.task_message_delta import TextDelta, ToolRequestDelta, ReasoningContentDelta, ReasoningSummaryDelta
6465
from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
65-
from agentex.types.task_message_content import TextContent, ReasoningContent
66+
from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent
6667
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
6768
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
6869
streaming_task_id,
@@ -678,12 +679,27 @@ async def get_response(
678679
streaming_mode=self.streaming_mode,
679680
).__aenter__()
680681
elif item and getattr(item, 'type', None) == 'function_call':
681-
# Track the function call being streamed
682+
# Open a streaming context per function call so argument
683+
# deltas can be published incrementally. Coalescing and
684+
# mode dispatch are handled by the streaming layer.
685+
call_id = getattr(item, 'call_id', '')
686+
tool_name = getattr(item, 'name', '')
687+
call_context = await adk.streaming.streaming_task_message_context(
688+
task_id=task_id,
689+
initial_content=ToolRequestContent(
690+
author="agent",
691+
tool_call_id=call_id,
692+
name=tool_name,
693+
arguments={},
694+
),
695+
streaming_mode=self.streaming_mode,
696+
).__aenter__()
682697
function_calls_in_progress[output_index] = {
683698
'id': getattr(item, 'id', ''),
684-
'call_id': getattr(item, 'call_id', ''),
685-
'name': getattr(item, 'name', ''),
699+
'call_id': call_id,
700+
'name': tool_name,
686701
'arguments': getattr(item, 'arguments', ''),
702+
'context': call_context,
687703
}
688704
logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}")
689705

@@ -704,8 +720,24 @@ async def get_response(
704720
output_index = getattr(event, 'output_index', 0)
705721
delta = getattr(event, 'delta', '')
706722

707-
if output_index in function_calls_in_progress:
708-
function_calls_in_progress[output_index]['arguments'] += delta
723+
call_data = function_calls_in_progress.get(output_index)
724+
if call_data is not None:
725+
call_data['arguments'] += delta
726+
call_context = call_data.get('context')
727+
if call_context is not None:
728+
try:
729+
await call_context.stream_update(StreamTaskMessageDelta(
730+
parent_task_message=call_context.task_message,
731+
delta=ToolRequestDelta(
732+
tool_call_id=call_data['call_id'],
733+
name=call_data['name'],
734+
arguments_delta=delta,
735+
type="tool_request",
736+
),
737+
type="delta",
738+
))
739+
except Exception as e:
740+
logger.warning(f"Failed to send tool request delta: {e}")
709741
logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...")
710742

711743
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
@@ -830,6 +862,39 @@ async def get_response(
830862
)
831863
output_items.append(tool_call)
832864

865+
# Emit the final ToolRequestContent and close the
866+
# per-call streaming context. If the model produced
867+
# invalid JSON args (truncation, hallucination), fall
868+
# back to an empty dict so the streaming layer can
869+
# still persist a message.
870+
call_context = call_data.get('context')
871+
if call_context is not None:
872+
raw_args = call_data['arguments'] or ''
873+
try:
874+
parsed_args = json.loads(raw_args) if raw_args else {}
875+
except json.JSONDecodeError:
876+
logger.warning(
877+
f"Failed to parse tool call arguments for {call_data['name']}: {raw_args[:200]}"
878+
)
879+
parsed_args = {}
880+
try:
881+
await call_context.stream_update(StreamTaskMessageFull(
882+
parent_task_message=call_context.task_message,
883+
content=ToolRequestContent(
884+
author="agent",
885+
tool_call_id=call_data['call_id'],
886+
name=call_data['name'],
887+
arguments=parsed_args,
888+
),
889+
type="full",
890+
))
891+
except Exception as e:
892+
logger.warning(f"Failed to send tool request full update: {e}")
893+
try:
894+
await call_context.close()
895+
except Exception as e:
896+
logger.warning(f"Failed to close tool request context: {e}")
897+
833898
elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
834899
# New reasoning part/summary started - reset accumulator
835900
part = getattr(event, 'part', None)
@@ -863,6 +928,17 @@ async def get_response(
863928
await streaming_context.close()
864929
streaming_context = None
865930

931+
# Defensive: close any function call contexts that didn't see a
932+
# ResponseOutputItemDoneEvent (truncated stream, error mid-call).
933+
for call_data in function_calls_in_progress.values():
934+
call_context = call_data.get('context')
935+
if call_context is not None:
936+
try:
937+
await call_context.close()
938+
except Exception as e:
939+
logger.warning(f"Failed to close orphaned tool request context: {e}")
940+
call_data['context'] = None
941+
866942
# Build the response from output items collected during streaming
867943
# Create output from the items we collected
868944
response_output = []

src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
from openai.types.responses import (
1313
ResponseCompletedEvent,
1414
ResponseTextDeltaEvent,
15+
ResponseOutputItemDoneEvent,
1516
ResponseOutputItemAddedEvent,
17+
ResponseFunctionCallArgumentsDoneEvent,
1618
ResponseReasoningSummaryTextDeltaEvent,
19+
ResponseFunctionCallArgumentsDeltaEvent,
1720
)
1821

1922

@@ -851,6 +854,197 @@ async def test_missing_task_id_error(self, streaming_model):
851854
)
852855

853856

857+
class TestStreamingModelFunctionCallArgsStreaming:
    """Verify ``ResponseFunctionCallArgumentsDeltaEvent``s are surfaced as
    ``ToolRequestDelta`` updates and that a final ``ToolRequestContent`` Full is
    emitted on ``ResponseOutputItemDoneEvent``.

    Without this, write-heavy tools (``write_file``, ``apply_patch``) buffer their
    entire argument body inside ``invoke_model_activity`` and the UI sees a
    multi-second freeze while the model is actively producing tokens.
    """

    @staticmethod
    def _build_function_call_stream(arguments_text: str):
        """Construct a streaming event sequence for a single function_call.

        Mirrors the production order: Added → N × ArgumentsDelta → ArgumentsDone
        → OutputItemDone → ResponseCompleted. ``spec=...`` makes ``isinstance``
        dispatch in production work without triggering pydantic validation.

        Args:
            arguments_text: The raw (possibly malformed) JSON argument string the
                simulated model "generates" for the tool call.

        Returns:
            A ``(events, chunks)`` tuple: the ordered list of mock events and the
            list of argument chunks carried by the delta events (``[""]`` when
            ``arguments_text`` is empty, so at least one delta is always emitted).
        """

        def mock_event(spec_cls, **attrs):
            # Attributes are assigned after construction so that names such as
            # ``name`` become plain attributes rather than MagicMock constructor
            # arguments.
            event = MagicMock(spec=spec_cls)
            for attr_name, attr_value in attrs.items():
                setattr(event, attr_name, attr_value)
            return event

        call_item = MagicMock()
        call_item.type = "function_call"
        call_item.id = "fc_abc"
        call_item.call_id = "call_abc"
        call_item.name = "write_file"
        call_item.arguments = ""

        item_added = mock_event(ResponseOutputItemAddedEvent, item=call_item, output_index=0)

        # Split the argument text into a few chunks to exercise the per-delta loop.
        chunk_size = max(1, len(arguments_text) // 3) if arguments_text else 1
        chunks = [
            arguments_text[start:start + chunk_size]
            for start in range(0, len(arguments_text), chunk_size)
        ] or [""]
        delta_events = [
            mock_event(ResponseFunctionCallArgumentsDeltaEvent, delta=chunk, output_index=0)
            for chunk in chunks
        ]

        args_done = mock_event(
            ResponseFunctionCallArgumentsDoneEvent,
            arguments=arguments_text,
            output_index=0,
        )
        item_done = mock_event(ResponseOutputItemDoneEvent, item=call_item, output_index=0)
        completed = mock_event(
            ResponseCompletedEvent,
            response=MagicMock(output=[], usage=MagicMock(), id=None),
        )

        return [item_added, *delta_events, args_done, item_done, completed], chunks

    @staticmethod
    def _install_real_task_message(mock_adk_streaming, task_id: str):
        """Replace the autouse fixture's MagicMock ``task_message`` with a real
        ``TaskMessage`` so production's ``StreamTaskMessageDelta(parent_task_message=...)``
        construction passes pydantic validation. The default mock works for tests
        that only assert on the context's ``__aenter__`` call but breaks tests
        that exercise ``stream_update`` end-to-end.

        Returns:
            The mocked streaming context whose ``task_message`` was replaced.
        """
        from agentex.types.task_message import TaskMessage
        from agentex.types.task_message_content import ToolRequestContent

        ctx = mock_adk_streaming.streaming_task_message_context.return_value
        ctx.task_message = TaskMessage(
            id="msg_test",
            task_id=task_id,
            content=ToolRequestContent(
                author="agent",
                tool_call_id="call_abc",
                name="write_file",
                arguments={},
            ),
            streaming_status="IN_PROGRESS",
        )
        return ctx

    @staticmethod
    def _streamed_updates(ctx, update_type, payload_attr: str, payload_type):
        """Collect the updates passed to ``ctx.stream_update`` that match a shape.

        An update matches when it is an instance of ``update_type`` and its
        ``payload_attr`` attribute is an instance of ``payload_type``. The update
        is read from the first positional argument, or from the ``update``
        keyword when the call was made by keyword, so neither calling style is
        silently dropped.

        Returns:
            The matching updates in call order.
        """
        matches = []
        for call in ctx.stream_update.call_args_list:
            update = call.args[0] if call.args else call.kwargs.get("update")
            if not isinstance(update, update_type):
                continue
            if isinstance(getattr(update, payload_attr, None), payload_type):
                matches.append(update)
        return matches

    @staticmethod
    async def _run_get_response(streaming_model, events):
        """Feed ``events`` through ``streaming_model.get_response``.

        Installs an ``AsyncMock`` stream that yields ``events`` as the return
        value of ``client.responses.create`` and then invokes ``get_response``
        with minimal arguments.
        """
        mock_stream = AsyncMock()
        mock_stream.__aiter__.return_value = iter(events)
        streaming_model.client.responses.create = AsyncMock(return_value=mock_stream)

        await streaming_model.get_response(
            system_instructions=None,
            input="please write foo",
            model_settings=ModelSettings(),
            tools=[],
            output_schema=None,
            handoffs=[],
            tracing=None,
        )

    @pytest.mark.asyncio
    async def test_function_call_emits_argument_deltas_and_final_full(
        self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id
    ):
        """A function_call with well-formed JSON args should produce:
        (1) one streaming context opened with ``ToolRequestContent`` initial_content,
        (2) one ``StreamTaskMessageDelta`` per ``ArgumentsDelta`` event carrying a
        ``ToolRequestDelta`` with the right ``tool_call_id`` and ``arguments_delta``,
        (3) one final ``StreamTaskMessageFull`` with ``ToolRequestContent`` whose
        ``arguments`` is the parsed JSON dict.
        """
        from agentex.types.task_message_delta import ToolRequestDelta
        from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
        from agentex.types.task_message_content import ToolRequestContent

        ctx = self._install_real_task_message(mock_adk_streaming, sample_task_id)

        args_text = '{"path": "/tmp/foo.txt", "contents": "hello world"}'
        events, chunks = self._build_function_call_stream(args_text)

        await self._run_get_response(streaming_model, events)

        # 1. A streaming context was opened with ToolRequestContent.
        opens = [
            c for c in mock_adk_streaming.streaming_task_message_context.call_args_list
            if isinstance(c.kwargs.get("initial_content"), ToolRequestContent)
        ]
        assert len(opens) == 1, f"expected one ToolRequest context, got {len(opens)}"
        initial = opens[0].kwargs["initial_content"]
        assert initial.tool_call_id == "call_abc"
        assert initial.name == "write_file"

        # 2. One StreamTaskMessageDelta(ToolRequestDelta) was streamed per
        #    ArgumentsDelta event, preserving the delta text exactly.
        delta_updates = self._streamed_updates(
            ctx, StreamTaskMessageDelta, "delta", ToolRequestDelta
        )
        assert len(delta_updates) == len(chunks)
        for update, expected_chunk in zip(delta_updates, chunks):
            assert update.delta.tool_call_id == "call_abc"
            assert update.delta.name == "write_file"
            assert update.delta.arguments_delta == expected_chunk

        # 3. A final StreamTaskMessageFull(ToolRequestContent) was streamed with
        #    parsed args.
        full_updates = self._streamed_updates(
            ctx, StreamTaskMessageFull, "content", ToolRequestContent
        )
        assert len(full_updates) == 1
        final = full_updates[0].content
        assert final.tool_call_id == "call_abc"
        assert final.name == "write_file"
        assert final.arguments == {"path": "/tmp/foo.txt", "contents": "hello world"}

    @pytest.mark.asyncio
    async def test_function_call_malformed_args_fall_back_to_empty_dict(
        self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id, caplog
    ):
        """If the model produces invalid JSON for the args, the final
        ``ToolRequestContent`` should carry ``arguments={}`` and a warning should
        be logged. The raw delta stream is preserved either way.
        """
        from agentex.types.task_message_update import StreamTaskMessageFull
        from agentex.types.task_message_content import ToolRequestContent

        ctx = self._install_real_task_message(mock_adk_streaming, sample_task_id)

        # Missing closing brace — invalid JSON.
        events, _ = self._build_function_call_stream('{"path": "/tmp/foo.txt", "contents":')

        with caplog.at_level("WARNING"):
            await self._run_get_response(streaming_model, events)

        full_updates = self._streamed_updates(
            ctx, StreamTaskMessageFull, "content", ToolRequestContent
        )
        assert len(full_updates) == 1
        assert full_updates[0].content.arguments == {}
        assert any("Failed to parse tool call arguments" in r.getMessage() for r in caplog.records)
1046+
1047+
8541048
class TestStreamingModelUsageResponseIdAndCacheKey:
8551049
"""Cover real-Usage capture, real response_id, span emission, and opt-in prompt_cache_key."""
8561050

0 commit comments

Comments
 (0)