|
12 | 12 | from openai.types.responses import ( |
13 | 13 | ResponseCompletedEvent, |
14 | 14 | ResponseTextDeltaEvent, |
| 15 | + ResponseOutputItemDoneEvent, |
15 | 16 | ResponseOutputItemAddedEvent, |
| 17 | + ResponseFunctionCallArgumentsDoneEvent, |
16 | 18 | ResponseReasoningSummaryTextDeltaEvent, |
| 19 | + ResponseFunctionCallArgumentsDeltaEvent, |
17 | 20 | ) |
18 | 21 |
|
19 | 22 |
|
@@ -851,6 +854,197 @@ async def test_missing_task_id_error(self, streaming_model): |
851 | 854 | ) |
852 | 855 |
|
853 | 856 |
|
| 857 | +class TestStreamingModelFunctionCallArgsStreaming: |
| 858 | + """Verify ``ResponseFunctionCallArgumentsDeltaEvent``s are surfaced as |
| 859 | + ``ToolRequestDelta`` updates and that a final ``ToolRequestContent`` Full is |
| 860 | + emitted on ``ResponseOutputItemDoneEvent``. |
| 861 | +
|
| 862 | + Without this, write-heavy tools (``write_file``, ``apply_patch``) buffer their |
| 863 | + entire argument body inside ``invoke_model_activity`` and the UI sees a |
| 864 | + multi-second freeze while the model is actively producing tokens. |
| 865 | + """ |
| 866 | + |
| 867 | + @staticmethod |
| 868 | + def _build_function_call_stream(arguments_text: str): |
| 869 | + """Construct a streaming event sequence for a single function_call. |
| 870 | +
|
| 871 | + Mirrors the production order: Added → N × ArgumentsDelta → ArgumentsDone |
| 872 | + → OutputItemDone → ResponseCompleted. ``spec=...`` makes ``isinstance`` |
| 873 | + dispatch in production work without triggering pydantic validation. |
| 874 | + """ |
| 875 | + call_item = MagicMock() |
| 876 | + call_item.type = "function_call" |
| 877 | + call_item.id = "fc_abc" |
| 878 | + call_item.call_id = "call_abc" |
| 879 | + call_item.name = "write_file" |
| 880 | + call_item.arguments = "" |
| 881 | + |
| 882 | + item_added = MagicMock(spec=ResponseOutputItemAddedEvent) |
| 883 | + item_added.item = call_item |
| 884 | + item_added.output_index = 0 |
| 885 | + |
| 886 | + # Split the argument text into a few chunks to exercise the per-delta loop |
| 887 | + chunk_size = max(1, len(arguments_text) // 3) if arguments_text else 1 |
| 888 | + chunks = [arguments_text[i:i + chunk_size] for i in range(0, len(arguments_text), chunk_size)] or [""] |
| 889 | + delta_events = [] |
| 890 | + for chunk in chunks: |
| 891 | + ev = MagicMock(spec=ResponseFunctionCallArgumentsDeltaEvent) |
| 892 | + ev.delta = chunk |
| 893 | + ev.output_index = 0 |
| 894 | + delta_events.append(ev) |
| 895 | + |
| 896 | + args_done = MagicMock(spec=ResponseFunctionCallArgumentsDoneEvent) |
| 897 | + args_done.arguments = arguments_text |
| 898 | + args_done.output_index = 0 |
| 899 | + |
| 900 | + item_done = MagicMock(spec=ResponseOutputItemDoneEvent) |
| 901 | + item_done.item = call_item |
| 902 | + item_done.output_index = 0 |
| 903 | + |
| 904 | + completed = MagicMock(spec=ResponseCompletedEvent) |
| 905 | + completed.response = MagicMock(output=[], usage=MagicMock(), id=None) |
| 906 | + |
| 907 | + return [item_added, *delta_events, args_done, item_done, completed], chunks |
| 908 | + |
| 909 | + @staticmethod |
| 910 | + def _install_real_task_message(mock_adk_streaming, task_id: str): |
| 911 | + """Replace the autouse fixture's MagicMock ``task_message`` with a real |
| 912 | + ``TaskMessage`` so production's ``StreamTaskMessageDelta(parent_task_message=...)`` |
| 913 | + construction passes pydantic validation. The default mock works for tests |
| 914 | + that only assert on the context's ``__aenter__`` call but breaks tests |
| 915 | + that exercise ``stream_update`` end-to-end. |
| 916 | + """ |
| 917 | + from agentex.types.task_message import TaskMessage |
| 918 | + from agentex.types.task_message_content import ToolRequestContent |
| 919 | + |
| 920 | + ctx = mock_adk_streaming.streaming_task_message_context.return_value |
| 921 | + ctx.task_message = TaskMessage( |
| 922 | + id="msg_test", |
| 923 | + task_id=task_id, |
| 924 | + content=ToolRequestContent( |
| 925 | + author="agent", |
| 926 | + tool_call_id="call_abc", |
| 927 | + name="write_file", |
| 928 | + arguments={}, |
| 929 | + ), |
| 930 | + streaming_status="IN_PROGRESS", |
| 931 | + ) |
| 932 | + return ctx |
| 933 | + |
| 934 | + @pytest.mark.asyncio |
| 935 | + async def test_function_call_emits_argument_deltas_and_final_full( |
| 936 | + self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id |
| 937 | + ): |
| 938 | + """A function_call with well-formed JSON args should produce: |
| 939 | + (1) one streaming context opened with ``ToolRequestContent`` initial_content, |
| 940 | + (2) one ``StreamTaskMessageDelta`` per ``ArgumentsDelta`` event carrying a |
| 941 | + ``ToolRequestDelta`` with the right ``tool_call_id`` and ``arguments_delta``, |
| 942 | + (3) one final ``StreamTaskMessageFull`` with ``ToolRequestContent`` whose |
| 943 | + ``arguments`` is the parsed JSON dict. |
| 944 | + """ |
| 945 | + from agentex.types.task_message_delta import ToolRequestDelta |
| 946 | + from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta |
| 947 | + from agentex.types.task_message_content import ToolRequestContent |
| 948 | + |
| 949 | + ctx = self._install_real_task_message(mock_adk_streaming, sample_task_id) |
| 950 | + |
| 951 | + args_text = '{"path": "/tmp/foo.txt", "contents": "hello world"}' |
| 952 | + events, chunks = self._build_function_call_stream(args_text) |
| 953 | + |
| 954 | + mock_stream = AsyncMock() |
| 955 | + mock_stream.__aiter__.return_value = iter(events) |
| 956 | + streaming_model.client.responses.create = AsyncMock(return_value=mock_stream) |
| 957 | + |
| 958 | + await streaming_model.get_response( |
| 959 | + system_instructions=None, |
| 960 | + input="please write foo", |
| 961 | + model_settings=ModelSettings(), |
| 962 | + tools=[], |
| 963 | + output_schema=None, |
| 964 | + handoffs=[], |
| 965 | + tracing=None, |
| 966 | + ) |
| 967 | + |
| 968 | + # 1. A streaming context was opened with ToolRequestContent. |
| 969 | + opens = [ |
| 970 | + c for c in mock_adk_streaming.streaming_task_message_context.call_args_list |
| 971 | + if isinstance(c.kwargs.get("initial_content"), ToolRequestContent) |
| 972 | + ] |
| 973 | + assert len(opens) == 1, f"expected one ToolRequest context, got {len(opens)}" |
| 974 | + initial = opens[0].kwargs["initial_content"] |
| 975 | + assert initial.tool_call_id == "call_abc" |
| 976 | + assert initial.name == "write_file" |
| 977 | + |
| 978 | + # 2. One StreamTaskMessageDelta(ToolRequestDelta) was streamed per |
| 979 | + # ArgumentsDelta event, preserving the delta text exactly. |
| 980 | + delta_updates = [ |
| 981 | + call.args[0] if call.args else call.kwargs.get("update") |
| 982 | + for call in ctx.stream_update.call_args_list |
| 983 | + if (call.args and isinstance(call.args[0], StreamTaskMessageDelta) |
| 984 | + and isinstance(call.args[0].delta, ToolRequestDelta)) |
| 985 | + ] |
| 986 | + assert len(delta_updates) == len(chunks) |
| 987 | + for update, expected_chunk in zip(delta_updates, chunks): |
| 988 | + assert update.delta.tool_call_id == "call_abc" |
| 989 | + assert update.delta.name == "write_file" |
| 990 | + assert update.delta.arguments_delta == expected_chunk |
| 991 | + |
| 992 | + # 3. A final StreamTaskMessageFull(ToolRequestContent) was streamed with |
| 993 | + # parsed args. |
| 994 | + full_updates = [ |
| 995 | + call.args[0] if call.args else call.kwargs.get("update") |
| 996 | + for call in ctx.stream_update.call_args_list |
| 997 | + if (call.args and isinstance(call.args[0], StreamTaskMessageFull) |
| 998 | + and isinstance(call.args[0].content, ToolRequestContent)) |
| 999 | + ] |
| 1000 | + assert len(full_updates) == 1 |
| 1001 | + final = full_updates[0].content |
| 1002 | + assert final.tool_call_id == "call_abc" |
| 1003 | + assert final.name == "write_file" |
| 1004 | + assert final.arguments == {"path": "/tmp/foo.txt", "contents": "hello world"} |
| 1005 | + |
| 1006 | + @pytest.mark.asyncio |
| 1007 | + async def test_function_call_malformed_args_fall_back_to_empty_dict( |
| 1008 | + self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id, caplog |
| 1009 | + ): |
| 1010 | + """If the model produces invalid JSON for the args, the final |
| 1011 | + ``ToolRequestContent`` should carry ``arguments={}`` and a warning should |
| 1012 | + be logged. The raw delta stream is preserved either way. |
| 1013 | + """ |
| 1014 | + from agentex.types.task_message_update import StreamTaskMessageFull |
| 1015 | + from agentex.types.task_message_content import ToolRequestContent |
| 1016 | + |
| 1017 | + ctx = self._install_real_task_message(mock_adk_streaming, sample_task_id) |
| 1018 | + |
| 1019 | + # Missing closing brace — invalid JSON. |
| 1020 | + events, _ = self._build_function_call_stream('{"path": "/tmp/foo.txt", "contents":') |
| 1021 | + |
| 1022 | + mock_stream = AsyncMock() |
| 1023 | + mock_stream.__aiter__.return_value = iter(events) |
| 1024 | + streaming_model.client.responses.create = AsyncMock(return_value=mock_stream) |
| 1025 | + |
| 1026 | + with caplog.at_level("WARNING"): |
| 1027 | + await streaming_model.get_response( |
| 1028 | + system_instructions=None, |
| 1029 | + input="please write foo", |
| 1030 | + model_settings=ModelSettings(), |
| 1031 | + tools=[], |
| 1032 | + output_schema=None, |
| 1033 | + handoffs=[], |
| 1034 | + tracing=None, |
| 1035 | + ) |
| 1036 | + |
| 1037 | + full_updates = [ |
| 1038 | + call.args[0] if call.args else call.kwargs.get("update") |
| 1039 | + for call in ctx.stream_update.call_args_list |
| 1040 | + if (call.args and isinstance(call.args[0], StreamTaskMessageFull) |
| 1041 | + and isinstance(call.args[0].content, ToolRequestContent)) |
| 1042 | + ] |
| 1043 | + assert len(full_updates) == 1 |
| 1044 | + assert full_updates[0].content.arguments == {} |
| 1045 | + assert any("Failed to parse tool call arguments" in r.getMessage() for r in caplog.records) |
| 1046 | + |
| 1047 | + |
854 | 1048 | class TestStreamingModelUsageResponseIdAndCacheKey: |
855 | 1049 | """Cover real-Usage capture, real response_id, span emission, and opt-in prompt_cache_key.""" |
856 | 1050 |
|
|
0 commit comments