
Commit 8da61be

hangfei authored and copybara-github committed
fix: Flush pending transcriptions on turn/generation complete or interrupt for Gemini API
The Gemini API may not always send an explicit transcription finished signal. This change ensures that any buffered input or output transcription text is yielded as a finished transcription when a turn is completed, generation is complete, or the session is interrupted.

Also refined the check for `event.partial` in runners.py to be more explicit.

Co-authored-by: Hangfei Lin <[email protected]>
PiperOrigin-RevId: 839008606
1 parent 98d8293 commit 8da61be
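The `event.partial` refinement is worth spelling out: in live mode the flag is effectively tri-state, since audio responses arrive with `event.partial=None` (see the runners.py note below), and the explicit `is not True` check treats `None` and `False` alike on purpose. A minimal sketch of the semantics, in illustrative code rather than ADK source:

```python
from typing import Optional


def is_non_partial(partial: Optional[bool]) -> bool:
  # partial=True  -> in-progress chunk; not yet safe to persist.
  # partial=False -> finalized event; persist it.
  # partial=None  -> e.g. a live audio event that never sets the flag;
  #                  deliberately treated as non-partial too.
  return partial is not True


assert is_non_partial(False)
assert is_non_partial(None)  # `not partial` also passes here, but only implicitly
assert not is_non_partial(True)
```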

File tree

5 files changed, +359 -7 lines changed


contributing/samples/live_bidi_streaming_single_agent/agent.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -65,8 +65,8 @@ async def check_prime(nums: list[int]) -> str:
 
 
 root_agent = Agent(
-    # model='gemini-live-2.5-flash-preview-native-audio-09-2025',  # vertex
-    model='gemini-2.5-flash-native-audio-preview-09-2025',  # for AI studio
+    model='gemini-live-2.5-flash-preview-native-audio-09-2025',  # vertex
+    # model='gemini-2.5-flash-native-audio-preview-09-2025',  # for AI studio
     # key
     name='roll_dice_agent',
     description=(
```
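The sample now defaults to the Vertex model name, with the AI Studio one commented out. To avoid hand-toggling the comments, one option is to pick the model from the environment; a small sketch, assuming the standard google-genai `GOOGLE_GENAI_USE_VERTEXAI` switch is what distinguishes the two backends in your setup:

```python
import os

# Sketch only: GOOGLE_GENAI_USE_VERTEXAI is the google-genai convention for
# selecting Vertex AI over AI Studio; adjust if your configuration differs.
_use_vertex = os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '').lower() in (
    '1',
    'true',
)
MODEL = (
    'gemini-live-2.5-flash-preview-native-audio-09-2025'  # vertex
    if _use_vertex
    else 'gemini-2.5-flash-native-audio-preview-09-2025'  # for AI studio
)
```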

src/google/adk/models/gemini_llm_connection.py

Lines changed: 36 additions & 1 deletion
```diff
@@ -21,6 +21,7 @@
 from google.genai import types
 
 from ..utils.context_utils import Aclosing
+from ..utils.variant_utils import GoogleLLMVariant
 from .base_llm_connection import BaseLlmConnection
 from .llm_response import LlmResponse
 
@@ -36,10 +37,15 @@
 class GeminiLlmConnection(BaseLlmConnection):
   """The Gemini model connection."""
 
-  def __init__(self, gemini_session: live.AsyncSession):
+  def __init__(
+      self,
+      gemini_session: live.AsyncSession,
+      api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI,
+  ):
     self._gemini_session = gemini_session
     self._input_transcription_text: str = ''
     self._output_transcription_text: str = ''
+    self._api_backend = api_backend
 
   async def send_history(self, history: list[types.Content]):
     """Sends the conversation history to the gemini model.
@@ -171,6 +177,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
           yield self.__build_full_text_response(text)
           text = ''
         yield llm_response
+        # Note: in some cases, tool_call may arrive before
+        # generation_complete, causing transcription to appear after
+        # tool_call in the session log.
       if message.server_content.input_transcription:
         if message.server_content.input_transcription.text:
           self._input_transcription_text += (
@@ -215,6 +224,32 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
               partial=False,
           )
           self._output_transcription_text = ''
+      # The Gemini API might not send a transcription finished signal.
+      # Instead, we rely on generation_complete, turn_complete or
+      # interrupted signals to flush any pending transcriptions.
+      if self._api_backend == GoogleLLMVariant.GEMINI_API and (
+          message.server_content.interrupted
+          or message.server_content.turn_complete
+          or message.server_content.generation_complete
+      ):
+        if self._input_transcription_text:
+          yield LlmResponse(
+              input_transcription=types.Transcription(
+                  text=self._input_transcription_text,
+                  finished=True,
+              ),
+              partial=False,
+          )
+          self._input_transcription_text = ''
+        if self._output_transcription_text:
+          yield LlmResponse(
+              output_transcription=types.Transcription(
+                  text=self._output_transcription_text,
+                  finished=True,
+              ),
+              partial=False,
+          )
+          self._output_transcription_text = ''
       if message.server_content.turn_complete:
         if text:
           yield self.__build_full_text_response(text)
```
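To see the new flush path in isolation: the connection accumulates transcription fragments into a buffer and, on the Gemini API backend, treats interrupted/turn_complete/generation_complete as implicit "transcription finished" signals. A self-contained simulation of that pattern (the dataclass below is a simplified stand-in, not the google.genai types):

```python
import dataclasses
from typing import Iterable, Iterator, Optional, Tuple


@dataclasses.dataclass
class FakeServerContent:
  # Simplified stand-in for the fields read off message.server_content.
  input_transcription_text: Optional[str] = None
  turn_complete: bool = False
  generation_complete: bool = False
  interrupted: bool = False


def flush_transcriptions(
    messages: Iterable[FakeServerContent], is_gemini_api: bool
) -> Iterator[Tuple[str, bool]]:
  """Yields (text, finished) pairs, flushing buffered text on terminal signals."""
  buffer = ''
  for content in messages:
    if content.input_transcription_text:
      buffer += content.input_transcription_text
    # The Gemini API may never send an explicit "finished" signal, so any
    # terminal signal flushes whatever transcription text is still buffered.
    if is_gemini_api and (
        content.interrupted
        or content.turn_complete
        or content.generation_complete
    ):
      if buffer:
        yield (buffer, True)
        buffer = ''


msgs = [
    FakeServerContent(input_transcription_text='Roll a '),
    FakeServerContent(input_transcription_text='die.'),
    FakeServerContent(turn_complete=True),
]
print(list(flush_transcriptions(msgs, is_gemini_api=True)))
# [('Roll a die.', True)]
```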

src/google/adk/models/google_llm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -342,7 +342,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
     async with self._live_api_client.aio.live.connect(
         model=llm_request.model, config=llm_request.live_connect_config
     ) as live_session:
-      yield GeminiLlmConnection(live_session)
+      yield GeminiLlmConnection(live_session, api_backend=self._api_backend)
 
   async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None:
     """Adapt the google computer use predefined functions to the adk computer use toolset."""
```

src/google/adk/runners.py

Lines changed: 81 additions & 2 deletions
```diff
@@ -67,6 +67,23 @@
 logger = logging.getLogger('google_adk.' + __name__)
 
 
+def _is_tool_call_or_response(event: Event) -> bool:
+  return bool(event.get_function_calls() or event.get_function_responses())
+
+
+def _is_transcription(event: Event) -> bool:
+  return (
+      event.input_transcription is not None
+      or event.output_transcription is not None
+  )
+
+
+def _has_non_empty_transcription_text(transcription) -> bool:
+  return bool(
+      transcription and transcription.text and transcription.text.strip()
+  )
+
+
 class Runner:
   """The Runner class is used to run agents.
 
@@ -626,6 +643,7 @@ async def _exec_with_plugin(
       invocation_context: The invocation context
       session: The current session
       execute_fn: A callable that returns an AsyncGenerator of Events
+      is_live_call: Whether this is a live call
 
     Yields:
       Events from the execution, including any generated by plugins
@@ -651,13 +669,74 @@ async def _exec_with_plugin(
        yield early_exit_event
     else:
       # Step 2: Otherwise continue with normal execution
+      # Note for live/bidi:
+      # the transcription may arrive later than the action (function call
+      # event and thus function response event). In this case, the order of
+      # transcription and function call event will be wrong if we just
+      # append as it arrives. To address this, we should check if there is
+      # a transcription going on. If there is a transcription going on, we
+      # should hold off appending the function call event until the
+      # transcription is finished. The transcription in progress can be
+      # identified by checking if the transcription event is partial. When
+      # the next transcription event is not partial, it means the previous
+      # transcription is finished. Then, if there are any buffered function
+      # call events, we should append them after this finished (non-partial)
+      # transcription event.
+      buffered_events: list[Event] = []
+      is_transcribing: bool = False
+
       async with Aclosing(execute_fn(invocation_context)) as agen:
         async for event in agen:
-          if not event.partial:
-            if self._should_append_event(event, is_live_call):
+          if is_live_call:
+            if event.partial and _is_transcription(event):
+              is_transcribing = True
+            if is_transcribing and _is_tool_call_or_response(event):
+              # Only buffer function call and function response events,
+              # which are non-partial.
+              buffered_events.append(event)
+              continue
+            # Note for live/bidi: an audio response is considered a
+            # non-partial event (event.partial=None).
+            # event.partial=False and event.partial=None are considered
+            # non-partial events; event.partial=True is considered a
+            # partial event.
+            if event.partial is not True:
+              if _is_transcription(event) and (
+                  _has_non_empty_transcription_text(event.input_transcription)
+                  or _has_non_empty_transcription_text(
+                      event.output_transcription
+                  )
+              ):
+                # Transcription end signal: append buffered events.
+                is_transcribing = False
+                logger.debug(
+                    'Appending transcription finished event: %s', event
+                )
+                if self._should_append_event(event, is_live_call):
+                  await self.session_service.append_event(
+                      session=session, event=event
+                  )
+
+                for buffered_event in buffered_events:
+                  logger.debug('Appending buffered event: %s', buffered_event)
+                  await self.session_service.append_event(
+                      session=session, event=buffered_event
+                  )
+                buffered_events = []
+              else:
+                # Non-transcription events or empty transcription events, for
+                # example events that store a blob reference, should be appended.
+                if self._should_append_event(event, is_live_call):
+                  logger.debug('Appending non-buffered event: %s', event)
+                  await self.session_service.append_event(
+                      session=session, event=event
+                  )
+          else:
+            if event.partial is not True:
               await self.session_service.append_event(
                   session=session, event=event
               )
+
           # Step 3: Run the on_event callbacks to optionally modify the event.
           modified_event = await plugin_manager.run_on_event_callback(
               invocation_context=invocation_context, event=event
```