From 8a0994b18955ad6dff01226cdedde419554d6548 Mon Sep 17 00:00:00 2001 From: Vijay Kalmath Date: Tue, 12 May 2026 11:30:34 -0400 Subject: [PATCH 1/2] feat(streaming): stream tool call argument deltas in TemporalStreamingModel Wire ResponseFunctionCallArgumentsDeltaEvent into the streaming layer introduced in #333, so write-heavy tools (write_file, apply_patch) no longer freeze the UI for the duration of argument generation. The model now opens a per-function-call streaming context with a ToolRequestContent placeholder, emits ToolRequestDelta updates for each argument delta, and finalizes with a StreamTaskMessageFull containing the parsed arguments on ResponseOutputItemDoneEvent. Coalescing and mode dispatch are inherited from the existing streaming infrastructure -- no new flags or surface area. ModelResponse output is unchanged; activity determinism is unaffected. End-of-loop cleanup defensively closes any function-call contexts that didn't see a Done event (truncated stream or mid-stream exception). Adds two tests covering the happy path (well-formed JSON args -> deltas + parsed Full) and the malformed-args fallback (invalid JSON -> empty dict + WARNING log). 
--- .../models/temporal_streaming_model.py | 90 +++++++- .../tests/test_streaming_model.py | 194 ++++++++++++++++++ 2 files changed, 277 insertions(+), 7 deletions(-) diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index 4f18ae379..40cfa9d3e 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -1,6 +1,7 @@ """Custom Temporal Model Provider with streaming support for OpenAI agents.""" from __future__ import annotations +import json import uuid from typing import Any, List, Union, Optional, override @@ -60,9 +61,9 @@ from agentex.lib import adk from agentex.lib.utils.logging import make_logger from agentex.lib.core.tracing.tracer import AsyncTracer -from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta +from agentex.types.task_message_delta import TextDelta, ToolRequestDelta, ReasoningContentDelta, ReasoningSummaryDelta from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta -from agentex.types.task_message_content import TextContent, ReasoningContent +from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent from agentex.lib.adk.utils._modules.client import create_async_agentex_client from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( streaming_task_id, @@ -678,12 +679,27 @@ async def get_response( streaming_mode=self.streaming_mode, ).__aenter__() elif item and getattr(item, 'type', None) == 'function_call': - # Track the function call being streamed + # Open a streaming context per function call so argument + # deltas can be published incrementally. Coalescing and + # mode dispatch are handled by the streaming layer. 
+ call_id = getattr(item, 'call_id', '') + tool_name = getattr(item, 'name', '') + call_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=ToolRequestContent( + author="agent", + tool_call_id=call_id, + name=tool_name, + arguments={}, + ), + streaming_mode=self.streaming_mode, + ).__aenter__() function_calls_in_progress[output_index] = { 'id': getattr(item, 'id', ''), - 'call_id': getattr(item, 'call_id', ''), - 'name': getattr(item, 'name', ''), + 'call_id': call_id, + 'name': tool_name, 'arguments': getattr(item, 'arguments', ''), + 'context': call_context, } logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}") @@ -704,8 +720,24 @@ async def get_response( output_index = getattr(event, 'output_index', 0) delta = getattr(event, 'delta', '') - if output_index in function_calls_in_progress: - function_calls_in_progress[output_index]['arguments'] += delta + call_data = function_calls_in_progress.get(output_index) + if call_data is not None: + call_data['arguments'] += delta + call_context = call_data.get('context') + if call_context is not None: + try: + await call_context.stream_update(StreamTaskMessageDelta( + parent_task_message=call_context.task_message, + delta=ToolRequestDelta( + tool_call_id=call_data['call_id'], + name=call_data['name'], + arguments_delta=delta, + type="tool_request", + ), + type="delta", + )) + except Exception as e: + logger.warning(f"Failed to send tool request delta: {e}") logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...") elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent): @@ -830,6 +862,39 @@ async def get_response( ) output_items.append(tool_call) + # Emit the final ToolRequestContent and close the + # per-call streaming context. If the model produced + # invalid JSON args (truncation, hallucination), fall + # back to an empty dict so the streaming layer can + # still persist a message. 
+ call_context = call_data.get('context') + if call_context is not None: + raw_args = call_data['arguments'] or '' + try: + parsed_args = json.loads(raw_args) if raw_args else {} + except json.JSONDecodeError: + logger.warning( + f"Failed to parse tool call arguments for {call_data['name']}: {raw_args[:200]}" + ) + parsed_args = {} + try: + await call_context.stream_update(StreamTaskMessageFull( + parent_task_message=call_context.task_message, + content=ToolRequestContent( + author="agent", + tool_call_id=call_data['call_id'], + name=call_data['name'], + arguments=parsed_args, + ), + type="full", + )) + except Exception as e: + logger.warning(f"Failed to send tool request full update: {e}") + try: + await call_context.close() + except Exception as e: + logger.warning(f"Failed to close tool request context: {e}") + elif isinstance(event, ResponseReasoningSummaryPartAddedEvent): # New reasoning part/summary started - reset accumulator part = getattr(event, 'part', None) @@ -863,6 +928,17 @@ async def get_response( await streaming_context.close() streaming_context = None + # Defensive: close any function call contexts that didn't see a + # ResponseOutputItemDoneEvent (truncated stream, error mid-call). 
+ for call_data in function_calls_in_progress.values(): + call_context = call_data.get('context') + if call_context is not None: + try: + await call_context.close() + except Exception as e: + logger.warning(f"Failed to close orphaned tool request context: {e}") + call_data['context'] = None + # Build the response from output items collected during streaming # Create output from the items we collected response_output = [] diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py index 97dda0e61..26c0b7c4b 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py @@ -12,8 +12,11 @@ from openai.types.responses import ( ResponseCompletedEvent, ResponseTextDeltaEvent, + ResponseOutputItemDoneEvent, ResponseOutputItemAddedEvent, + ResponseFunctionCallArgumentsDoneEvent, ResponseReasoningSummaryTextDeltaEvent, + ResponseFunctionCallArgumentsDeltaEvent, ) @@ -851,6 +854,197 @@ async def test_missing_task_id_error(self, streaming_model): ) +class TestStreamingModelFunctionCallArgsStreaming: + """Verify ``ResponseFunctionCallArgumentsDeltaEvent``s are surfaced as + ``ToolRequestDelta`` updates and that a final ``ToolRequestContent`` Full is + emitted on ``ResponseOutputItemDoneEvent``. + + Without this, write-heavy tools (``write_file``, ``apply_patch``) buffer their + entire argument body inside ``invoke_model_activity`` and the UI sees a + multi-second freeze while the model is actively producing tokens. + """ + + @staticmethod + def _build_function_call_stream(arguments_text: str): + """Construct a streaming event sequence for a single function_call. + + Mirrors the production order: Added → N × ArgumentsDelta → ArgumentsDone + → OutputItemDone → ResponseCompleted. 
``spec=...`` makes ``isinstance`` + dispatch in production work without triggering pydantic validation. + """ + call_item = MagicMock() + call_item.type = "function_call" + call_item.id = "fc_abc" + call_item.call_id = "call_abc" + call_item.name = "write_file" + call_item.arguments = "" + + item_added = MagicMock(spec=ResponseOutputItemAddedEvent) + item_added.item = call_item + item_added.output_index = 0 + + # Split the argument text into a few chunks to exercise the per-delta loop + chunk_size = max(1, len(arguments_text) // 3) if arguments_text else 1 + chunks = [arguments_text[i:i + chunk_size] for i in range(0, len(arguments_text), chunk_size)] or [""] + delta_events = [] + for chunk in chunks: + ev = MagicMock(spec=ResponseFunctionCallArgumentsDeltaEvent) + ev.delta = chunk + ev.output_index = 0 + delta_events.append(ev) + + args_done = MagicMock(spec=ResponseFunctionCallArgumentsDoneEvent) + args_done.arguments = arguments_text + args_done.output_index = 0 + + item_done = MagicMock(spec=ResponseOutputItemDoneEvent) + item_done.item = call_item + item_done.output_index = 0 + + completed = MagicMock(spec=ResponseCompletedEvent) + completed.response = MagicMock(output=[], usage=MagicMock(), id=None) + + return [item_added, *delta_events, args_done, item_done, completed], chunks + + @staticmethod + def _install_real_task_message(mock_adk_streaming, task_id: str): + """Replace the autouse fixture's MagicMock ``task_message`` with a real + ``TaskMessage`` so production's ``StreamTaskMessageDelta(parent_task_message=...)`` + construction passes pydantic validation. The default mock works for tests + that only assert on the context's ``__aenter__`` call but breaks tests + that exercise ``stream_update`` end-to-end. 
+ """ + from agentex.types.task_message import TaskMessage + from agentex.types.task_message_content import ToolRequestContent + + ctx = mock_adk_streaming.streaming_task_message_context.return_value + ctx.task_message = TaskMessage( + id="msg_test", + task_id=task_id, + content=ToolRequestContent( + author="agent", + tool_call_id="call_abc", + name="write_file", + arguments={}, + ), + streaming_status="IN_PROGRESS", + ) + return ctx + + @pytest.mark.asyncio + async def test_function_call_emits_argument_deltas_and_final_full( + self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id + ): + """A function_call with well-formed JSON args should produce: + (1) one streaming context opened with ``ToolRequestContent`` initial_content, + (2) one ``StreamTaskMessageDelta`` per ``ArgumentsDelta`` event carrying a + ``ToolRequestDelta`` with the right ``tool_call_id`` and ``arguments_delta``, + (3) one final ``StreamTaskMessageFull`` with ``ToolRequestContent`` whose + ``arguments`` is the parsed JSON dict. + """ + from agentex.types.task_message_delta import ToolRequestDelta + from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta + from agentex.types.task_message_content import ToolRequestContent + + ctx = self._install_real_task_message(mock_adk_streaming, sample_task_id) + + args_text = '{"path": "/tmp/foo.txt", "contents": "hello world"}' + events, chunks = self._build_function_call_stream(args_text) + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter(events) + streaming_model.client.responses.create = AsyncMock(return_value=mock_stream) + + await streaming_model.get_response( + system_instructions=None, + input="please write foo", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + # 1. A streaming context was opened with ToolRequestContent. 
+ opens = [ + c for c in mock_adk_streaming.streaming_task_message_context.call_args_list + if isinstance(c.kwargs.get("initial_content"), ToolRequestContent) + ] + assert len(opens) == 1, f"expected one ToolRequest context, got {len(opens)}" + initial = opens[0].kwargs["initial_content"] + assert initial.tool_call_id == "call_abc" + assert initial.name == "write_file" + + # 2. One StreamTaskMessageDelta(ToolRequestDelta) was streamed per + # ArgumentsDelta event, preserving the delta text exactly. + delta_updates = [ + call.args[0] if call.args else call.kwargs.get("update") + for call in ctx.stream_update.call_args_list + if (call.args and isinstance(call.args[0], StreamTaskMessageDelta) + and isinstance(call.args[0].delta, ToolRequestDelta)) + ] + assert len(delta_updates) == len(chunks) + for update, expected_chunk in zip(delta_updates, chunks): + assert update.delta.tool_call_id == "call_abc" + assert update.delta.name == "write_file" + assert update.delta.arguments_delta == expected_chunk + + # 3. A final StreamTaskMessageFull(ToolRequestContent) was streamed with + # parsed args. + full_updates = [ + call.args[0] if call.args else call.kwargs.get("update") + for call in ctx.stream_update.call_args_list + if (call.args and isinstance(call.args[0], StreamTaskMessageFull) + and isinstance(call.args[0].content, ToolRequestContent)) + ] + assert len(full_updates) == 1 + final = full_updates[0].content + assert final.tool_call_id == "call_abc" + assert final.name == "write_file" + assert final.arguments == {"path": "/tmp/foo.txt", "contents": "hello world"} + + @pytest.mark.asyncio + async def test_function_call_malformed_args_fall_back_to_empty_dict( + self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id, caplog + ): + """If the model produces invalid JSON for the args, the final + ``ToolRequestContent`` should carry ``arguments={}`` and a warning should + be logged. The raw delta stream is preserved either way. 
+ """ + from agentex.types.task_message_update import StreamTaskMessageFull + from agentex.types.task_message_content import ToolRequestContent + + ctx = self._install_real_task_message(mock_adk_streaming, sample_task_id) + + # Missing closing brace — invalid JSON. + events, _ = self._build_function_call_stream('{"path": "/tmp/foo.txt", "contents":') + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter(events) + streaming_model.client.responses.create = AsyncMock(return_value=mock_stream) + + with caplog.at_level("WARNING"): + await streaming_model.get_response( + system_instructions=None, + input="please write foo", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + full_updates = [ + call.args[0] if call.args else call.kwargs.get("update") + for call in ctx.stream_update.call_args_list + if (call.args and isinstance(call.args[0], StreamTaskMessageFull) + and isinstance(call.args[0].content, ToolRequestContent)) + ] + assert len(full_updates) == 1 + assert full_updates[0].content.arguments == {} + assert any("Failed to parse tool call arguments" in r.getMessage() for r in caplog.records) + + class TestStreamingModelUsageResponseIdAndCacheKey: """Cover real-Usage capture, real response_id, span emission, and opt-in prompt_cache_key.""" From 18ea7af15a80c473328d184ed503f295aaeec3e5 Mon Sep 17 00:00:00 2001 From: Vijay Kalmath Date: Tue, 12 May 2026 11:52:36 -0400 Subject: [PATCH 2/2] fix(streaming): drop raw tool args from WARNING log Logging raw_args[:200] could leak partial file contents, PII, or secrets from write_file / apply_patch arguments into production log pipelines. Switch to logging only bounded metadata (tool name + raw arg length; the logged raw_args_bytes value is len(raw_args) on a str, i.e. a character count rather than an encoded byte count). The existing malformed-args test still passes since it asserts on the "Failed to parse tool call arguments" prefix, which is preserved. 
--- .../plugins/openai_agents/models/temporal_streaming_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index 40cfa9d3e..727448706 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -874,7 +874,8 @@ async def get_response( parsed_args = json.loads(raw_args) if raw_args else {} except json.JSONDecodeError: logger.warning( - f"Failed to parse tool call arguments for {call_data['name']}: {raw_args[:200]}" + f"Failed to parse tool call arguments for {call_data['name']} " + f"(raw_args_bytes={len(raw_args)})" ) parsed_args = {} try: