22
33import asyncio
44import dataclasses
5+ import uuid
56from abc import ABC
67from collections .abc import AsyncIterator , Iterator , Sequence
78from contextlib import asynccontextmanager , contextmanager
@@ -267,43 +268,93 @@ async def run(
267268
@dataclasses.dataclass
class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]):
    """Process the response from a model, decide whether to end the run or make a new request."""

    model_response: _messages.ModelResponse

    # Lazily-created stream of handle-response events; created at most once per node.
    _stream: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False)
    # Populated by the stream before it finishes: the node the graph transitions to next.
    _next_node: ModelRequestNode[DepsT, NodeRunEndT] | FinalResultNode[DepsT, NodeRunEndT] | None = field(
        default=None, repr=False
    )
    # Tool return/retry parts accumulated while executing tool calls (also used for span attributes).
    _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False)
274281 async def run (
275282 self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
276283 ) -> Union [ModelRequestNode [DepsT , NodeRunEndT ], FinalResultNode [DepsT , NodeRunEndT ]]: # noqa UP007
284+ async with self .run_stream (ctx ):
285+ pass
286+
287+ # the stream should set `self._next_node` before it ends:
288+ assert (next_node := self ._next_node ) is not None
289+ return next_node
290+
291+ @asynccontextmanager
292+ async def run_stream (
293+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
294+ ) -> AsyncIterator [AsyncIterator [_messages .HandleResponseEvent ]]:
277295 with _logfire .span ('handle model response' , run_step = ctx .state .run_step ) as handle_span :
278- texts : list [str ] = []
279- tool_calls : list [_messages .ToolCallPart ] = []
280- for part in self .model_response .parts :
281- if isinstance (part , _messages .TextPart ):
282- # ignore empty content for text parts, see #437
283- if part .content :
284- texts .append (part .content )
285- elif isinstance (part , _messages .ToolCallPart ):
286- tool_calls .append (part )
296+ stream = self ._run_stream (ctx )
297+ yield stream
298+
299+ # Run the stream to completion if it was not finished:
300+ async for _event in stream :
301+ pass
302+
303+ # Set the next node based on the final state of the stream
304+ next_node = self ._next_node
305+ if isinstance (next_node , FinalResultNode ):
306+ handle_span .set_attribute ('result' , next_node .data )
307+ handle_span .message = 'handle model response -> final result'
308+ elif tool_responses := self ._tool_responses :
309+ # TODO: We could drop `self._tool_responses` if we drop this set_attribute
310+ # I'm thinking it might be better to just create a span for the handling of each tool
311+ # than to set an attribute here.
312+ handle_span .set_attribute ('tool_responses' , tool_responses )
313+ tool_responses_str = ' ' .join (r .part_kind for r in tool_responses )
314+ handle_span .message = f'handle model response -> { tool_responses_str } '
315+
316+ async def _run_stream (
317+ self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]
318+ ) -> AsyncIterator [_messages .HandleResponseEvent ]:
319+ if self ._stream is None :
320+ # Ensure that the stream is only run once
321+
322+ async def _run_stream () -> AsyncIterator [_messages .HandleResponseEvent ]:
323+ texts : list [str ] = []
324+ tool_calls : list [_messages .ToolCallPart ] = []
325+ for part in self .model_response .parts :
326+ if isinstance (part , _messages .TextPart ):
327+ # ignore empty content for text parts, see #437
328+ if part .content :
329+ texts .append (part .content )
330+ elif isinstance (part , _messages .ToolCallPart ):
331+ tool_calls .append (part )
332+ else :
333+ assert_never (part )
334+
335+ # At the moment, we prioritize at least executing tool calls if they are present.
336+ # In the future, we'd consider making this configurable at the agent or run level.
337+ # This accounts for cases like anthropic returns that might contain a text response
338+ # and a tool call response, where the text response just indicates the tool call will happen.
339+ if tool_calls :
340+ async for event in self ._handle_tool_calls (ctx , tool_calls ):
341+ yield event
342+ elif texts :
343+ # No events are emitted during the handling of text responses, so we don't need to yield anything
344+ self ._next_node = await self ._handle_text_response (ctx , texts )
287345 else :
288- assert_never (part )
289-
290- # At the moment, we prioritize at least executing tool calls if they are present.
291- # In the future, we'd consider making this configurable at the agent or run level.
292- # This accounts for cases like anthropic returns that might contain a text response
293- # and a tool call response, where the text response just indicates the tool call will happen.
294- if tool_calls :
295- return await self ._handle_tool_calls_response (ctx , tool_calls , handle_span )
296- elif texts :
297- return await self ._handle_text_response (ctx , texts , handle_span )
298- else :
299- raise exceptions .UnexpectedModelBehavior ('Received empty model response' )
346+ raise exceptions .UnexpectedModelBehavior ('Received empty model response' )
347+
348+ self ._stream = _run_stream ()
349+
350+ async for event in self ._stream :
351+ yield event
300352
301- async def _handle_tool_calls_response (
353+ async def _handle_tool_calls (
302354 self ,
303355 ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
304356 tool_calls : list [_messages .ToolCallPart ],
305- handle_span : logfire_api .LogfireSpan ,
306- ):
357+ ) -> AsyncIterator [_messages .HandleResponseEvent ]:
307358 result_schema = ctx .deps .result_schema
308359
309360 # first look for the result tool call
@@ -324,26 +375,24 @@ async def _handle_tool_calls_response(
324375 final_result = MarkFinalResult (result_data , call .tool_name )
325376
326377 # Then build the other request parts based on end strategy
327- tool_responses = await _process_function_tools (tool_calls , final_result and final_result .tool_name , ctx )
378+ tool_responses : list [_messages .ModelRequestPart ] = self ._tool_responses
379+ async for event in _process_function_tools (
380+ tool_calls , final_result and final_result .tool_name , ctx , tool_responses
381+ ):
382+ yield event
328383
329384 if final_result :
330- handle_span .set_attribute ('result' , final_result .data )
331- handle_span .message = 'handle model response -> final result'
332- return FinalResultNode [DepsT , NodeRunEndT ](final_result , tool_responses )
385+ self ._next_node = FinalResultNode [DepsT , NodeRunEndT ](final_result , tool_responses )
333386 else :
334387 if tool_responses :
335- handle_span .set_attribute ('tool_responses' , tool_responses )
336- tool_responses_str = ' ' .join (r .part_kind for r in tool_responses )
337- handle_span .message = f'handle model response -> { tool_responses_str } '
338388 parts .extend (tool_responses )
339- return ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = parts ))
389+ self . _next_node = ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = parts ))
340390
341391 async def _handle_text_response (
342392 self ,
343393 ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
344394 texts : list [str ],
345- handle_span : logfire_api .LogfireSpan ,
346- ):
395+ ) -> ModelRequestNode [DepsT , NodeRunEndT ] | FinalResultNode [DepsT , NodeRunEndT ]:
347396 result_schema = ctx .deps .result_schema
348397
349398 text = '\n \n ' .join (texts )
@@ -355,8 +404,6 @@ async def _handle_text_response(
355404 ctx .state .increment_retries (ctx .deps .max_result_retries )
356405 return ModelRequestNode [DepsT , NodeRunEndT ](_messages .ModelRequest (parts = [e .tool_retry ]))
357406 else :
358- handle_span .set_attribute ('result' , result_data )
359- handle_span .message = 'handle model response -> final result'
360407 return FinalResultNode [DepsT , NodeRunEndT ](MarkFinalResult (result_data , None ))
361408 else :
362409 ctx .state .increment_retries (ctx .deps .max_result_retries )
@@ -560,11 +607,15 @@ async def on_complete():
560607 last_message = messages [- 1 ]
561608 assert isinstance (last_message , _messages .ModelResponse )
562609 tool_calls = [part for part in last_message .parts if isinstance (part , _messages .ToolCallPart )]
563- parts = await _process_function_tools (
610+
611+ parts : list [_messages .ModelRequestPart ] = []
612+ async for _event in _process_function_tools (
564613 tool_calls ,
565614 result_tool_name ,
566615 ctx ,
567- )
616+ parts ,
617+ ):
618+ pass
568619 # TODO: Should we do something here related to the retry count?
569620 # Maybe we should move the incrementing of the retry count to where we actually make a request?
570621 # if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
async def _process_function_tools(
    tool_calls: list[_messages.ToolCallPart],
    result_tool_name: str | None,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
    output_parts: list[_messages.ModelRequestPart],
) -> AsyncIterator[_messages.HandleResponseEvent]:
    """Process function (non-result) tool calls in parallel.

    Also add stub return parts for any other tools that need it.

    Because async iterators can't have return values, we use `output_parts` as an
    output argument: tool return / retry parts are appended to it while call and
    result events are yielded to the caller.
    """
    stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early'
    result_schema = ctx.deps.result_schema

    # we rely on the fact that if we found a result, it's the first result tool in the last
    found_used_result_tool = False
    run_context = _build_run_context(ctx)

    calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
    call_index_to_event_id: dict[int, uuid.UUID] = {}
    for call in tool_calls:
        if call.tool_name == result_tool_name and not found_used_result_tool:
            found_used_result_tool = True
            output_parts.append(
                _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content='Final result processed.',
                    tool_call_id=call.tool_call_id,
                )
            )
        elif tool := ctx.deps.function_tools.get(call.tool_name):
            if stub_function_tools:
                # Early end strategy: a final result exists, so skip running the tool
                output_parts.append(
                    _messages.ToolReturnPart(
                        tool_name=call.tool_name,
                        content='Tool not executed - a final result was already processed.',
                        tool_call_id=call.tool_call_id,
                    )
                )
            else:
                event = _messages.FunctionToolCallEvent(call)
                yield event
                call_index_to_event_id[len(calls_to_run)] = event.call_id
                calls_to_run.append((tool, call))
        elif result_schema is not None and call.tool_name in result_schema.tools:
            # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
            # validation, we don't add another part here
            if result_tool_name is not None:
                part = _messages.ToolReturnPart(
                    tool_name=call.tool_name,
                    content='Result tool not used - a final result was already processed.',
                    tool_call_id=call.tool_call_id,
                )
                output_parts.append(part)
        else:
            output_parts.append(_unknown_tool(call.tool_name, ctx))

    if not calls_to_run:
        return

    # Run all tool tasks in parallel
    results_by_index: dict[int, _messages.ModelRequestPart] = {}
    with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]):
        # TODO: Should we wrap each individual tool call in a dedicated span?
        tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run]
        # Precompute task -> index so completion handling is O(1) per task
        # instead of an O(n) `tasks.index(task)` scan inside the loop.
        task_to_index = {task: i for i, task in enumerate(tasks)}
        pending: set[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = set(tasks)
        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                index = task_to_index[task]
                result = task.result()
                yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index])
                if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)):
                    results_by_index[index] = result
                else:
                    assert_never(result)

    # We append the results at the end, rather than as they are received, to retain a consistent ordering
    # This is mostly just to simplify testing
    for k in sorted(results_by_index):
        output_parts.append(results_by_index[k])
656723
657724def _unknown_tool (
0 commit comments