Skip to content
68 changes: 40 additions & 28 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from google.adk.platform import time as platform_time
from google.genai import types
from opentelemetry import trace
from websockets.exceptions import ConnectionClosed
from websockets.exceptions import ConnectionClosedOK

Expand Down Expand Up @@ -1102,28 +1103,47 @@ async def _call_llm_async(
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
# Runs before_model_callback if it exists.
if response := await self._handle_before_model_callback(
invocation_context, llm_request, model_response_event
):
yield response
return

llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.labels = llm_request.config.labels or {}
async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm') as span:
# Runs before_model_callback if it exists.
# This must be inside the call_llm span so that before_model_callback
# and after_model_callback/on_model_error_callback all share the same
# span context (fixes issue #4851).
if response := await self._handle_before_model_callback(
invocation_context, llm_request, model_response_event
):
yield response
return

llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.labels = llm_request.config.labels or {}

# Add agent name as a label to the llm_request. This will help with
# slicing the billing reports on a per-agent basis.
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
invocation_context.agent.name
)

# Add agent name as a label to the llm_request. This will help with slicing
# the billing reports on a per-agent basis.
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
invocation_context.agent.name
)
# Calls the LLM.
llm = self.__get_llm(invocation_context)

# Calls the LLM.
llm = self.__get_llm(invocation_context)
async def _apply_after_model_callback(
response: LlmResponse,
) -> LlmResponse:
"""Applies after_model_callback within the call_llm span context.

Re-activates the call_llm span so after_model_callback sees the
same span_id as before_model_callback (issue #4851).
"""
with trace.use_span(span):
if altered := await self._handle_after_model_callback(
invocation_context, response, model_response_event
):
return altered
return response

async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm') as span:
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
responses_generator = self.run_live(invocation_context)
Expand All @@ -1136,11 +1156,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
)
) as agen:
async for llm_response in agen:
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
llm_response = await _apply_after_model_callback(llm_response)
# only yield partial response in SSE streaming mode
if (
invocation_context.run_config.streaming_mode
Expand Down Expand Up @@ -1176,11 +1192,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
llm_response,
span,
)
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
llm_response = await _apply_after_model_callback(llm_response)

yield llm_response

Expand Down
232 changes: 232 additions & 0 deletions tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests that LLM callbacks share the same OTel span context (issue #4851).

When OpenTelemetry tracing is enabled, before_model_callback,
after_model_callback, and on_model_error_callback must all execute within
the same call_llm span so that plugins (e.g. BigQueryAgentAnalyticsPlugin)
see a consistent span_id for LLM_REQUEST and LLM_RESPONSE events.
"""

from typing import Optional
from unittest import mock

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.flows.llm_flows import base_llm_flow
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.telemetry import tracing as adk_tracing
from google.genai import types
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
import pytest

from ... import testing_utils


def _make_real_tracer():
  """Build a tracer backed by a real SDK TracerProvider.

  Unlike the default no-op tracer, this produces spans with valid
  (non-zero) span IDs, which the tests below assert on.
  """
  return TracerProvider().get_tracer('test_tracer')


class SpanCapturingPlugin(BasePlugin):
  """Plugin that records the active OTel span ID in each model callback.

  Used to verify that before_model_callback, after_model_callback, and
  on_model_error_callback all observe the same call_llm span.
  """

  def __init__(self):
    super().__init__(name='span_capturing_plugin')
    # Span IDs observed by each callback; stay None until the callback
    # runs inside a valid (non-zero) span.
    self.before_model_span_id: Optional[int] = None
    self.after_model_span_id: Optional[int] = None
    self.on_model_error_span_id: Optional[int] = None

  @staticmethod
  def _current_span_id() -> Optional[int]:
    """Return the current span's ID, or None when no valid span is active."""
    span_context = trace.get_current_span().get_span_context()
    if span_context and span_context.span_id:
      return span_context.span_id
    return None

  async def before_model_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
  ) -> Optional[LlmResponse]:
    """Records the span ID before the model call; never short-circuits."""
    captured = self._current_span_id()
    if captured is not None:
      self.before_model_span_id = captured
    return None

  async def after_model_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_response: LlmResponse,
  ) -> Optional[LlmResponse]:
    """Records the span ID after the model call; leaves the response as-is."""
    captured = self._current_span_id()
    if captured is not None:
      self.after_model_span_id = captured
    return None

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    """Records the span ID on model error and recovers with a stub response."""
    captured = self._current_span_id()
    if captured is not None:
      self.on_model_error_span_id = captured
    recovery_part = types.Part.from_text(text='error handled')
    return LlmResponse(
        content=testing_utils.ModelContent([recovery_part])
    )


@pytest.mark.asyncio
async def test_before_and_after_model_callbacks_share_span_id():
  """Verify before_model_callback and after_model_callback share the same span.

  This is the core regression test for issue #4851. Before the fix,
  before_model_callback ran outside the call_llm span, causing a span_id
  mismatch between LLM_REQUEST and LLM_RESPONSE events.
  """
  plugin = SpanCapturingPlugin()
  real_tracer = _make_real_tracer()

  mock_model = testing_utils.MockModel.create(responses=['model_response'])
  agent = Agent(
      name='test_agent',
      model=mock_model,
  )

  # Patch both the flow-level and telemetry-level tracers so the run uses a
  # real SDK tracer that emits valid (non-zero) span IDs.
  with (
      mock.patch.object(base_llm_flow, 'tracer', real_tracer),
      mock.patch.object(adk_tracing, 'tracer', real_tracer),
  ):
    runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
    # The returned events are not needed; only the span IDs captured by the
    # plugin matter (previously assigned to an unused local).
    await runner.run_async_with_new_session('test')

  # Both callbacks should have captured a span ID
  assert (
      plugin.before_model_span_id is not None
  ), 'before_model_callback did not capture a span ID'
  assert (
      plugin.after_model_span_id is not None
  ), 'after_model_callback did not capture a span ID'

  # The span IDs must match — this is the core assertion for issue #4851
  assert plugin.before_model_span_id == plugin.after_model_span_id, (
      'Span ID mismatch: before_model_callback span_id='
      f'{plugin.before_model_span_id:#018x}, '
      f'after_model_callback span_id={plugin.after_model_span_id:#018x}. '
      'Both callbacks must run inside the same call_llm span.'
  )


@pytest.mark.asyncio
async def test_before_and_on_error_model_callbacks_share_span_id():
  """Verify before_model_callback and on_model_error_callback share span.

  When the model raises an error, on_model_error_callback should see the
  same span as before_model_callback.
  """
  plugin = SpanCapturingPlugin()
  real_tracer = _make_real_tracer()

  # Configure the mock model to raise so the error-callback path runs.
  mock_model = testing_utils.MockModel.create(
      responses=[], error=SystemError('model error')
  )
  agent = Agent(
      name='test_agent',
      model=mock_model,
  )

  # Patch both the flow-level and telemetry-level tracers so the run uses a
  # real SDK tracer that emits valid (non-zero) span IDs.
  with (
      mock.patch.object(base_llm_flow, 'tracer', real_tracer),
      mock.patch.object(adk_tracing, 'tracer', real_tracer),
  ):
    runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
    # The returned events are not needed; only the span IDs captured by the
    # plugin matter (previously assigned to an unused local).
    await runner.run_async_with_new_session('test')

  # Both callbacks should have captured a span ID
  assert (
      plugin.before_model_span_id is not None
  ), 'before_model_callback did not capture a span ID'
  assert (
      plugin.on_model_error_span_id is not None
  ), 'on_model_error_callback did not capture a span ID'

  # The span IDs must match
  assert plugin.before_model_span_id == plugin.on_model_error_span_id, (
      'Span ID mismatch: before_model_callback span_id='
      f'{plugin.before_model_span_id:#018x}, '
      'on_model_error_callback span_id='
      f'{plugin.on_model_error_span_id:#018x}. '
      'Both callbacks must run inside the same call_llm span.'
  )


@pytest.mark.asyncio
async def test_before_model_callback_short_circuit_has_span():
  """Verify before_model_callback has a valid span when short-circuiting."""

  class ShortCircuitPlugin(BasePlugin):
    """Short-circuits the LLM call while recording the active span ID."""

    def __init__(self):
      super().__init__(name='short_circuit_plugin')
      # Span ID seen inside before_model_callback; None until it runs.
      self.span_id: Optional[int] = None

    async def before_model_callback(
        self,
        *,
        callback_context: CallbackContext,
        llm_request: LlmRequest,
    ) -> Optional[LlmResponse]:
      current_ctx = trace.get_current_span().get_span_context()
      if current_ctx and current_ctx.span_id:
        self.span_id = current_ctx.span_id
      # Returning a response here skips the real model call entirely.
      short_circuit_part = types.Part.from_text(text='short-circuited')
      return LlmResponse(
          content=testing_utils.ModelContent([short_circuit_part])
      )

  sc_plugin = ShortCircuitPlugin()
  sdk_tracer = _make_real_tracer()

  agent_under_test = Agent(
      name='test_agent',
      model=testing_utils.MockModel.create(responses=['model_response']),
  )

  # Swap in a real SDK tracer for both the flow and the telemetry module so
  # spans carry valid (non-zero) IDs during the run.
  with (
      mock.patch.object(base_llm_flow, 'tracer', sdk_tracer),
      mock.patch.object(adk_tracing, 'tracer', sdk_tracer),
  ):
    runner = testing_utils.TestInMemoryRunner(
        agent_under_test, plugins=[sc_plugin]
    )
    session_events = await runner.run_async_with_new_session('test')

  # The callback should have a valid (non-zero) span ID from the call_llm span
  assert sc_plugin.span_id is not None and sc_plugin.span_id != 0, (
      'before_model_callback should have a valid span ID even when '
      'short-circuiting the LLM call'
  )

  # Verify the short-circuit response was received
  simplified = testing_utils.simplify_events(session_events)
  assert any('short-circuited' in str(e) for e in simplified)