-
Notifications
You must be signed in to change notification settings - Fork 8
feat(streaming): stream tool call argument deltas in TemporalStreamingModel #355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||||||||||||||||||||||||||
| """Custom Temporal Model Provider with streaming support for OpenAI agents.""" | ||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||
| import uuid | ||||||||||||||||||||||||||||||
| from typing import Any, List, Union, Optional, override | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -60,9 +61,9 @@ | |||||||||||||||||||||||||||||
| from agentex.lib import adk | ||||||||||||||||||||||||||||||
| from agentex.lib.utils.logging import make_logger | ||||||||||||||||||||||||||||||
| from agentex.lib.core.tracing.tracer import AsyncTracer | ||||||||||||||||||||||||||||||
| from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta | ||||||||||||||||||||||||||||||
| from agentex.types.task_message_delta import TextDelta, ToolRequestDelta, ReasoningContentDelta, ReasoningSummaryDelta | ||||||||||||||||||||||||||||||
| from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta | ||||||||||||||||||||||||||||||
| from agentex.types.task_message_content import TextContent, ReasoningContent | ||||||||||||||||||||||||||||||
| from agentex.types.task_message_content import TextContent, ReasoningContent, ToolRequestContent | ||||||||||||||||||||||||||||||
| from agentex.lib.adk.utils._modules.client import create_async_agentex_client | ||||||||||||||||||||||||||||||
| from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( | ||||||||||||||||||||||||||||||
| streaming_task_id, | ||||||||||||||||||||||||||||||
|
|
@@ -678,12 +679,27 @@ async def get_response( | |||||||||||||||||||||||||||||
| streaming_mode=self.streaming_mode, | ||||||||||||||||||||||||||||||
| ).__aenter__() | ||||||||||||||||||||||||||||||
| elif item and getattr(item, 'type', None) == 'function_call': | ||||||||||||||||||||||||||||||
| # Track the function call being streamed | ||||||||||||||||||||||||||||||
| # Open a streaming context per function call so argument | ||||||||||||||||||||||||||||||
| # deltas can be published incrementally. Coalescing and | ||||||||||||||||||||||||||||||
| # mode dispatch are handled by the streaming layer. | ||||||||||||||||||||||||||||||
| call_id = getattr(item, 'call_id', '') | ||||||||||||||||||||||||||||||
| tool_name = getattr(item, 'name', '') | ||||||||||||||||||||||||||||||
| call_context = await adk.streaming.streaming_task_message_context( | ||||||||||||||||||||||||||||||
| task_id=task_id, | ||||||||||||||||||||||||||||||
| initial_content=ToolRequestContent( | ||||||||||||||||||||||||||||||
| author="agent", | ||||||||||||||||||||||||||||||
| tool_call_id=call_id, | ||||||||||||||||||||||||||||||
| name=tool_name, | ||||||||||||||||||||||||||||||
| arguments={}, | ||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||
| streaming_mode=self.streaming_mode, | ||||||||||||||||||||||||||||||
| ).__aenter__() | ||||||||||||||||||||||||||||||
| function_calls_in_progress[output_index] = { | ||||||||||||||||||||||||||||||
| 'id': getattr(item, 'id', ''), | ||||||||||||||||||||||||||||||
| 'call_id': getattr(item, 'call_id', ''), | ||||||||||||||||||||||||||||||
| 'name': getattr(item, 'name', ''), | ||||||||||||||||||||||||||||||
| 'call_id': call_id, | ||||||||||||||||||||||||||||||
| 'name': tool_name, | ||||||||||||||||||||||||||||||
| 'arguments': getattr(item, 'arguments', ''), | ||||||||||||||||||||||||||||||
| 'context': call_context, | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
| logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -704,8 +720,24 @@ async def get_response( | |||||||||||||||||||||||||||||
| output_index = getattr(event, 'output_index', 0) | ||||||||||||||||||||||||||||||
| delta = getattr(event, 'delta', '') | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if output_index in function_calls_in_progress: | ||||||||||||||||||||||||||||||
| function_calls_in_progress[output_index]['arguments'] += delta | ||||||||||||||||||||||||||||||
| call_data = function_calls_in_progress.get(output_index) | ||||||||||||||||||||||||||||||
| if call_data is not None: | ||||||||||||||||||||||||||||||
| call_data['arguments'] += delta | ||||||||||||||||||||||||||||||
| call_context = call_data.get('context') | ||||||||||||||||||||||||||||||
| if call_context is not None: | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| await call_context.stream_update(StreamTaskMessageDelta( | ||||||||||||||||||||||||||||||
| parent_task_message=call_context.task_message, | ||||||||||||||||||||||||||||||
| delta=ToolRequestDelta( | ||||||||||||||||||||||||||||||
| tool_call_id=call_data['call_id'], | ||||||||||||||||||||||||||||||
| name=call_data['name'], | ||||||||||||||||||||||||||||||
| arguments_delta=delta, | ||||||||||||||||||||||||||||||
| type="tool_request", | ||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||
| type="delta", | ||||||||||||||||||||||||||||||
| )) | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
| logger.warning(f"Failed to send tool request delta: {e}") | ||||||||||||||||||||||||||||||
| logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent): | ||||||||||||||||||||||||||||||
|
|
@@ -830,6 +862,40 @@ async def get_response( | |||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| output_items.append(tool_call) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Emit the final ToolRequestContent and close the | ||||||||||||||||||||||||||||||
| # per-call streaming context. If the model produced | ||||||||||||||||||||||||||||||
| # invalid JSON args (truncation, hallucination), fall | ||||||||||||||||||||||||||||||
| # back to an empty dict so the streaming layer can | ||||||||||||||||||||||||||||||
| # still persist a message. | ||||||||||||||||||||||||||||||
| call_context = call_data.get('context') | ||||||||||||||||||||||||||||||
| if call_context is not None: | ||||||||||||||||||||||||||||||
| raw_args = call_data['arguments'] or '' | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| parsed_args = json.loads(raw_args) if raw_args else {} | ||||||||||||||||||||||||||||||
| except json.JSONDecodeError: | ||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||
| f"Failed to parse tool call arguments for {call_data['name']} " | ||||||||||||||||||||||||||||||
| f"(raw_args_bytes={len(raw_args)})" | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| parsed_args = {} | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| await call_context.stream_update(StreamTaskMessageFull( | ||||||||||||||||||||||||||||||
| parent_task_message=call_context.task_message, | ||||||||||||||||||||||||||||||
| content=ToolRequestContent( | ||||||||||||||||||||||||||||||
| author="agent", | ||||||||||||||||||||||||||||||
| tool_call_id=call_data['call_id'], | ||||||||||||||||||||||||||||||
| name=call_data['name'], | ||||||||||||||||||||||||||||||
| arguments=parsed_args, | ||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||
| type="full", | ||||||||||||||||||||||||||||||
| )) | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
| logger.warning(f"Failed to send tool request full update: {e}") | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| await call_context.close() | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
| logger.warning(f"Failed to close tool request context: {e}") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| elif isinstance(event, ResponseReasoningSummaryPartAddedEvent): | ||||||||||||||||||||||||||||||
|
Comment on lines
+894
to
899
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Prompt To Fix With AIThis is a comment left during a code review.
Path: src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py
Line: 893-898
Comment:
**Double-close of streaming context on every successful function call** — after `ResponseOutputItemDoneEvent` closes the context (line 894), `call_data['context']` is left non-`None`. The orphan-cleanup loop (lines 933–940) then iterates `function_calls_in_progress.values()` and closes every entry whose `'context'` is not `None` — which includes every context that was already closed normally. `close()` will be called twice for every function call that finished cleanly, potentially publishing a duplicate "stream ended" event to Redis for each tool call in the response.
```suggestion
try:
await call_context.close()
except Exception as e:
logger.warning(f"Failed to close tool request context: {e}")
finally:
call_data['context'] = None
elif isinstance(event, ResponseReasoningSummaryPartAddedEvent):
```
How can I resolve this? If you propose a fix, please make it concise. |
||||||||||||||||||||||||||||||
| # New reasoning part/summary started - reset accumulator | ||||||||||||||||||||||||||||||
| part = getattr(event, 'part', None) | ||||||||||||||||||||||||||||||
|
|
@@ -863,6 +929,17 @@ async def get_response( | |||||||||||||||||||||||||||||
| await streaming_context.close() | ||||||||||||||||||||||||||||||
| streaming_context = None | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Defensive: close any function call contexts that didn't see a | ||||||||||||||||||||||||||||||
| # ResponseOutputItemDoneEvent (truncated stream, error mid-call). | ||||||||||||||||||||||||||||||
| for call_data in function_calls_in_progress.values(): | ||||||||||||||||||||||||||||||
| call_context = call_data.get('context') | ||||||||||||||||||||||||||||||
| if call_context is not None: | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| await call_context.close() | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
| logger.warning(f"Failed to close orphaned tool request context: {e}") | ||||||||||||||||||||||||||||||
| call_data['context'] = None | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Build the response from output items collected during streaming | ||||||||||||||||||||||||||||||
| # Create output from the items we collected | ||||||||||||||||||||||||||||||
| response_output = [] | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.