99import logfire_api
1010from typing_extensions import TypeVar
1111
12- from . import _result , exceptions , messages as _messages , models
12+ from . import _result , _utils , exceptions , messages as _messages , models
1313from .tools import AgentDepsT , RunContext
1414from .usage import Usage , UsageLimits
1515
@@ -160,7 +160,6 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Resu
160160 Returns:
161161 An async iterable of the response data.
162162 """
163- self ._stream_response .stream_structured (debounce_by = debounce_by )
164163 async for structured_message , is_last in self .stream_structured (debounce_by = debounce_by ):
165164 result = await self .validate_structured_result (structured_message , allow_partial = not is_last )
166165 yield result
@@ -183,11 +182,11 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
183182
184183 with _logfire .span ('response stream text' ) as lf_span :
185184 if delta :
186- async for text in self ._stream_response . stream_text (delta = delta , debounce_by = debounce_by ):
185+ async for text in self ._stream_response_text (delta = delta , debounce_by = debounce_by ):
187186 yield text
188187 else :
189188 combined_validated_text = ''
190- async for text in self ._stream_response . stream_text (delta = delta , debounce_by = debounce_by ):
189+ async for text in self ._stream_response_text (delta = delta , debounce_by = debounce_by ):
191190 combined_validated_text = await self ._validate_text_result (text )
192191 yield combined_validated_text
193192 lf_span .set_attribute ('combined_text' , combined_validated_text )
@@ -214,7 +213,7 @@ async def stream_structured(
214213 yield msg , False
215214 break
216215
217- async for msg in self ._stream_response . stream_structured (debounce_by = debounce_by ):
216+ async for msg in self ._stream_response_structured (debounce_by = debounce_by ):
218217 yield msg , False
219218
220219 msg = self ._stream_response .get ()
@@ -289,6 +288,61 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None:
289288 self ._all_messages .append (message )
290289 await self ._on_complete ()
291290
291+ async def _stream_response_structured (
292+ self , * , debounce_by : float | None = 0.1
293+ ) -> AsyncIterator [_messages .ModelResponse ]:
294+ async with _utils .group_by_temporal (self ._stream_response , debounce_by ) as group_iter :
295+ async for _items in group_iter :
296+ yield self ._stream_response .get ()
297+
298+ async def _stream_response_text (
299+ self , * , delta : bool = False , debounce_by : float | None = 0.1
300+ ) -> AsyncIterator [str ]:
301+ """Stream the response as an async iterable of text."""
302+
303+ # Define a "merged" version of the iterator that will yield items that have already been retrieved
304+ # and items that we receive while streaming. We define a dedicated async iterator for this so we can
305+ # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
306+ async def _stream_text_deltas_ungrouped () -> AsyncIterator [tuple [str , int ]]:
307+ # yields tuples of (text_content, part_index)
308+ # we don't currently make use of the part_index, but in principle this may be useful
309+ # so we retain it here for now to make possible future refactors simpler
310+ msg = self ._stream_response .get ()
311+ for i , part in enumerate (msg .parts ):
312+ if isinstance (part , _messages .TextPart ) and part .content :
313+ yield part .content , i
314+
315+ async for event in self ._stream_response :
316+ if (
317+ isinstance (event , _messages .PartStartEvent )
318+ and isinstance (event .part , _messages .TextPart )
319+ and event .part .content
320+ ):
321+ yield event .part .content , event .index
322+ elif (
323+ isinstance (event , _messages .PartDeltaEvent )
324+ and isinstance (event .delta , _messages .TextPartDelta )
325+ and event .delta .content_delta
326+ ):
327+ yield event .delta .content_delta , event .index
328+
329+ async def _stream_text_deltas () -> AsyncIterator [str ]:
330+ async with _utils .group_by_temporal (_stream_text_deltas_ungrouped (), debounce_by ) as group_iter :
331+ async for items in group_iter :
332+ # Note: we are currently just dropping the part index on the group here
333+ yield '' .join ([content for content , _ in items ])
334+
335+ if delta :
336+ async for text in _stream_text_deltas ():
337+ yield text
338+ else :
339+ # a quick benchmark shows it's faster to build up a string with concat when we're
340+ # yielding at each step
341+ deltas : list [str ] = []
342+ async for text in _stream_text_deltas ():
343+ deltas .append (text )
344+ yield '' .join (deltas )
345+
292346
293347@dataclass
294348class FinalResult (Generic [ResultDataT ]):
0 commit comments