Skip to content

fix: Save output in state via output_key only when the event is authored by current agent #1680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,22 @@ def _llm_flow(self) -> BaseLlmFlow:

def __maybe_save_output_to_state(self, event: Event):
"""Saves the model output to state if needed."""
# skip if the event was authored by some other agent (e.g. current agent
# transferred to another agent)
if event.author != self.name:
logger.debug(
'Skipping output save for agent %s: event authored by %s',
self.name,
event.author,
)
return
if (
self.output_key
and event.is_final_response()
and event.content
and event.content.parts
):

result = ''.join(
[part.text if part.text else '' for part in event.content.parts]
)
Expand Down
210 changes: 210 additions & 0 deletions tests/unittests/agents/test_llm_agent_output_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for LlmAgent output saving functionality."""

from unittest.mock import Mock
from unittest.mock import patch

from google.adk.agents.llm_agent import LlmAgent
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.genai import types
from pydantic import BaseModel
import pytest


class MockOutputSchema(BaseModel):
message: str
confidence: float


def create_test_event(
author: str = "test_agent",
content_text: str = "Hello world",
is_final: bool = True,
invocation_id: str = "test_invocation",
) -> Event:
"""Helper to create test events."""
# Create mock content
parts = [types.Part.from_text(text=content_text)] if content_text else []
content = types.Content(role="model", parts=parts) if parts else None

# Create event
event = Event(
invocation_id=invocation_id,
author=author,
content=content,
actions=EventActions(),
)

# Mock is_final_response if needed
if not is_final:
event.partial = True

return event


class TestLlmAgentOutputSave:
"""Test suite for LlmAgent output saving functionality."""

def test_maybe_save_output_to_state_skips_different_author(self, caplog):
"""Test that output is not saved when event author differs from agent name."""
agent = LlmAgent(name="agent_a", output_key="result")
event = create_test_event(author="agent_b", content_text="Response from B")

with caplog.at_level("DEBUG"):
agent._LlmAgent__maybe_save_output_to_state(event)

# Should not add anything to state_delta
assert len(event.actions.state_delta) == 0

# Should log the skip
assert (
"Skipping output save for agent agent_a: event authored by agent_b"
in caplog.text
)

def test_maybe_save_output_to_state_saves_same_author(self):
"""Test that output is saved when event author matches agent name."""
agent = LlmAgent(name="test_agent", output_key="result")
event = create_test_event(author="test_agent", content_text="Test response")

agent._LlmAgent__maybe_save_output_to_state(event)

# Should save to state_delta
assert event.actions.state_delta["result"] == "Test response"

def test_maybe_save_output_to_state_no_output_key(self):
"""Test that nothing is saved when output_key is not set."""
agent = LlmAgent(name="test_agent") # No output_key
event = create_test_event(author="test_agent", content_text="Test response")

agent._LlmAgent__maybe_save_output_to_state(event)

# Should not save anything
assert len(event.actions.state_delta) == 0

def test_maybe_save_output_to_state_not_final_response(self):
"""Test that output is not saved for non-final responses."""
agent = LlmAgent(name="test_agent", output_key="result")
event = create_test_event(
author="test_agent", content_text="Partial response", is_final=False
)

agent._LlmAgent__maybe_save_output_to_state(event)

# Should not save partial responses
assert len(event.actions.state_delta) == 0

def test_maybe_save_output_to_state_no_content(self):
"""Test that nothing is saved when event has no content."""
agent = LlmAgent(name="test_agent", output_key="result")
event = create_test_event(author="test_agent", content_text="")

agent._LlmAgent__maybe_save_output_to_state(event)

# Should not save empty content
assert len(event.actions.state_delta) == 0

def test_maybe_save_output_to_state_with_output_schema(self):
"""Test that output is processed with schema when output_schema is set."""
agent = LlmAgent(
name="test_agent", output_key="result", output_schema=MockOutputSchema
)

# Create event with JSON content
json_content = '{"message": "Hello", "confidence": 0.95}'
event = create_test_event(author="test_agent", content_text=json_content)

agent._LlmAgent__maybe_save_output_to_state(event)

# Should save parsed and validated output
expected_output = {"message": "Hello", "confidence": 0.95}
assert event.actions.state_delta["result"] == expected_output

def test_maybe_save_output_to_state_multiple_parts(self):
"""Test that multiple text parts are concatenated."""
agent = LlmAgent(name="test_agent", output_key="result")

# Create event with multiple text parts
parts = [
types.Part.from_text(text="Hello "),
types.Part.from_text(text="world"),
types.Part.from_text(text="!"),
]
content = types.Content(role="model", parts=parts)

event = Event(
invocation_id="test_invocation",
author="test_agent",
content=content,
actions=EventActions(),
)

agent._LlmAgent__maybe_save_output_to_state(event)

# Should concatenate all text parts
assert event.actions.state_delta["result"] == "Hello world!"

def test_maybe_save_output_to_state_agent_transfer_scenario(self, caplog):
"""Test realistic agent transfer scenario."""
# Scenario: Agent A transfers to Agent B, Agent B produces output
# Agent A should not save Agent B's output

agent_a = LlmAgent(name="support_agent", output_key="support_result")
agent_b_event = create_test_event(
author="billing_agent", content_text="Your bill is $100"
)

with caplog.at_level("DEBUG"):
agent_a._LlmAgent__maybe_save_output_to_state(agent_b_event)

# Agent A should not save Agent B's output
assert len(agent_b_event.actions.state_delta) == 0
assert (
"Skipping output save for agent support_agent: event authored by"
" billing_agent"
in caplog.text
)

def test_maybe_save_output_to_state_case_sensitive_names(self, caplog):
"""Test that agent name comparison is case-sensitive."""
agent = LlmAgent(name="TestAgent", output_key="result")
event = create_test_event(author="testagent", content_text="Test response")

with caplog.at_level("DEBUG"):
agent._LlmAgent__maybe_save_output_to_state(event)

# Should not save due to case mismatch
assert len(event.actions.state_delta) == 0
assert (
"Skipping output save for agent TestAgent: event authored by testagent"
in caplog.text
)

@patch("google.adk.agents.llm_agent.logger")
def test_maybe_save_output_to_state_logging(self, mock_logger):
"""Test that debug logging works correctly."""
agent = LlmAgent(name="agent1", output_key="result")
event = create_test_event(author="agent2", content_text="Test response")

agent._LlmAgent__maybe_save_output_to_state(event)

# Should call logger.debug with correct parameters
mock_logger.debug.assert_called_once_with(
"Skipping output save for agent %s: event authored by %s",
"agent1",
"agent2",
)