diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index bd0037bdcb..60df450521 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -692,8 +692,19 @@ def get_author_for_event(llm_response):
       else:
         return invocation_context.agent.name
 
+    # Track whether the model has produced spoken/text content across
+    # receive() cycles. After the first content-bearing cycle, re-entering
+    # the while-True loop is only valid if a function response was sent
+    # back to the model (the model needs to process the tool result).
+    # Without this guard, orphaned function responses from fire-and-forget
+    # tool calls land in the next receive() cycle as fresh input, causing
+    # the model to generate a complete duplicate response.
+    _has_yielded_content = False
+
     try:
       while True:
+        _cycle_had_function_response = False
+
         async with Aclosing(llm_connection.receive()) as agen:
           async for llm_response in agen:
             if llm_response.live_session_resumption_update:
@@ -738,7 +749,30 @@
                     invocation_context, audio_blob, cache_type='output'
                 )
 
+            # Track content and function responses for loop
+            # control. Content (audio/text) means the model has
+            # spoken. Function responses mean the caller sent a
+            # tool result back and the model may respond to it.
+            if event.content and event.content.parts:
+              for part in event.content.parts:
+                if part.inline_data or part.text:
+                  _has_yielded_content = True
+                if part.function_response:
+                  _cycle_had_function_response = True
+
             yield event
+
+        # Prevent orphaned re-entry after content delivery.
+        #
+        # If the model has already produced content and this cycle
+        # did NOT process a function response, there is nothing
+        # pending for the model to respond to. Re-entering receive()
+        # would consume any orphaned function response as fresh
+        # input, causing a duplicate response. See:
+        # https://github.com/google/adk-python/issues/4902
+        if _has_yielded_content and not _cycle_had_function_response:
+          break
+
         # Give opportunity for other tasks to run.
         await asyncio.sleep(0)
     except ConnectionClosedOK:
diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py
index 3dfadbcabf..6ed7d10a72 100644
--- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py
+++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py
@@ -487,3 +487,238 @@ async def call(self, **kwargs):
   assert result1.grounding_metadata == {'foo': 'bar'}
   assert result2.grounding_metadata == {'foo': 'bar'}
   assert result3.grounding_metadata == {'foo': 'bar'}
+
+
+# ---------------------------------------------------------------------------
+# Tests for _receive_from_model loop control (issue #4902)
+# ---------------------------------------------------------------------------
+
+
+class _MultiCycleMockConnection:
+  """Mock connection that yields different responses per receive() call.
+
+  Each call to receive() returns the next sequence from `cycles`.
+  After all cycles are exhausted, subsequent calls yield nothing.
+  """
+
+  def __init__(self, cycles: list[list[LlmResponse]]):
+    self._cycles = cycles
+    self._call_count = 0
+
+  async def send_history(self, history):
+    pass
+
+  async def send_content(self, content):
+    pass
+
+  async def receive(self):
+    idx = self._call_count
+    self._call_count += 1
+    if idx < len(self._cycles):
+      for resp in self._cycles[idx]:
+        yield resp
+
+
+@pytest.mark.asyncio
+async def test_receive_from_model_breaks_after_content_no_fn_response():
+  """After yielding content, loop should break if no function response.
+
+  This prevents orphaned function responses from triggering duplicate
+  model responses (issue #4902).
+  """
+  flow = BaseLlmFlowForTesting()
+  agent = Agent(name='test_agent', model='mock')
+  invocation_context = await testing_utils.create_invocation_context(
+      agent=agent,
+      user_content='Hello',
+  )
+
+  # Cycle 1: audio content + turnComplete (model spoke)
+  # Cycle 2: should NOT be reached (loop should break after cycle 1)
+  connection = _MultiCycleMockConnection(
+      cycles=[
+          [
+              LlmResponse(
+                  content=types.Content(
+                      role='model',
+                      parts=[
+                          types.Part(
+                              inline_data=types.Blob(
+                                  data=b'\x00',
+                                  mime_type='audio/pcm',
+                              ),
+                          )
+                      ],
+                  ),
+              ),
+              LlmResponse(turn_complete=True),
+          ],
+          [
+              LlmResponse(
+                  content=types.Content(
+                      role='model',
+                      parts=[types.Part.from_text(text='DUPLICATE')],
+                  ),
+              ),
+              LlmResponse(turn_complete=True),
+          ],
+      ]
+  )
+
+  events = []
+  async for event in flow._receive_from_model(
+      connection,
+      'test',
+      invocation_context,
+      LlmRequest(),
+  ):
+    events.append(event)
+
+  # Should have events from cycle 1 only, not the DUPLICATE from cycle 2
+  texts = [
+      p.text
+      for e in events
+      if e.content and e.content.parts
+      for p in e.content.parts
+      if p.text
+  ]
+  assert 'DUPLICATE' not in texts, (
+      'Loop re-entered after content delivery, producing a duplicate.'
+      ' _receive_from_model should break after yielding content when'
+      ' no function response is pending.'
+  )
+  assert connection._call_count == 1, (
+      f'Expected 1 receive() call, got {connection._call_count}.'
+      ' Loop should break after content + turnComplete.'
+  )
+
+
+@pytest.mark.asyncio
+async def test_receive_from_model_continues_when_fn_response_pending():
+  """Loop should continue if a function response was processed.
+
+  When the model calls a tool and the caller sends the response back,
+  the model is expected to produce a follow-up response. The loop
+  must re-enter receive() to collect it.
+  """
+  flow = BaseLlmFlowForTesting()
+  agent = Agent(name='test_agent', model='mock')
+  invocation_context = await testing_utils.create_invocation_context(
+      agent=agent,
+      user_content='Hello',
+  )
+
+  # Cycle 1: function response (tool result sent back) + turnComplete
+  # Cycle 2: model speaks in response to tool result + turnComplete
+  connection = _MultiCycleMockConnection(
+      cycles=[
+          [
+              LlmResponse(
+                  content=types.Content(
+                      role='model',
+                      parts=[
+                          types.Part(
+                              function_response=types.FunctionResponse(
+                                  name='my_tool',
+                                  response={'result': 'ok'},
+                              ),
+                          )
+                      ],
+                  ),
+              ),
+              LlmResponse(turn_complete=True),
+          ],
+          [
+              LlmResponse(
+                  content=types.Content(
+                      role='model',
+                      parts=[types.Part.from_text(text='Tool result spoken')],
+                  ),
+              ),
+              LlmResponse(turn_complete=True),
+          ],
+      ]
+  )
+
+  events = []
+  async for event in flow._receive_from_model(
+      connection,
+      'test',
+      invocation_context,
+      LlmRequest(),
+  ):
+    events.append(event)
+
+  texts = [
+      p.text
+      for e in events
+      if e.content and e.content.parts
+      for p in e.content.parts
+      if p.text
+  ]
+  assert 'Tool result spoken' in texts, (
+      'Loop should continue to cycle 2 when a function response was'
+      ' pending in cycle 1.'
+  )
+  assert (
+      connection._call_count == 2
+  ), f'Expected 2 receive() calls, got {connection._call_count}.'
+
+
+@pytest.mark.asyncio
+async def test_receive_from_model_continues_on_empty_cycle():
+  """Loop should continue on empty cycles (NR retry).
+
+  When the model sends turnComplete without any content or function
+  responses, the loop should re-enter to allow implicit retry. This
+  preserves the existing behaviour for the no-response pattern.
+  """
+  flow = BaseLlmFlowForTesting()
+  agent = Agent(name='test_agent', model='mock')
+  invocation_context = await testing_utils.create_invocation_context(
+      agent=agent,
+      user_content='Hello',
+  )
+
+  # Cycle 1: empty turnComplete (NR pattern)
+  # Cycle 2: model speaks on retry
+  connection = _MultiCycleMockConnection(
+      cycles=[
+          [
+              LlmResponse(turn_complete=True),
+          ],
+          [
+              LlmResponse(
+                  content=types.Content(
+                      role='model',
+                      parts=[types.Part.from_text(text='Retry worked')],
+                  ),
+              ),
+              LlmResponse(turn_complete=True),
+          ],
+      ]
+  )
+
+  events = []
+  async for event in flow._receive_from_model(
+      connection,
+      'test',
+      invocation_context,
+      LlmRequest(),
+  ):
+    events.append(event)
+
+  texts = [
+      p.text
+      for e in events
+      if e.content and e.content.parts
+      for p in e.content.parts
+      if p.text
+  ]
+  assert 'Retry worked' in texts, (
+      'Loop should continue past empty cycle (NR retry) and yield'
+      ' content from cycle 2.'
+  )
+  assert (
+      connection._call_count == 2
+  ), f'Expected 2 receive() calls, got {connection._call_count}.'