Skip to content

feat: Allow toolset to process llm_request before tools returned by it #2013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ...telemetry import trace_call_llm
from ...telemetry import trace_send_data
from ...telemetry import tracer
from ...tools.base_toolset import BaseToolset
from ...tools.tool_context import ToolContext

if TYPE_CHECKING:
Expand Down Expand Up @@ -341,13 +342,25 @@ async def _preprocess_async(
yield event

# Run processors for tools.
for tool in await agent.canonical_tools(
ReadonlyContext(invocation_context)
):
for tool_union in agent.tools:
tool_context = ToolContext(invocation_context)
await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request

# If it's a toolset, process it first
if isinstance(tool_union, BaseToolset):
await tool_union.process_llm_request(
tool_context=tool_context, llm_request=llm_request
)

from ...agents.llm_agent import _convert_tool_union_to_tools

# Then process all tools from this tool union
tools = await _convert_tool_union_to_tools(
tool_union, ReadonlyContext(invocation_context)
)
for tool in tools:
await tool.process_llm_request(
tool_context=tool_context, llm_request=llm_request
)

async def _postprocess_async(
self,
Expand Down
22 changes: 22 additions & 0 deletions src/google/adk/tools/base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@
from typing import Optional
from typing import Protocol
from typing import runtime_checkable
from typing import TYPE_CHECKING
from typing import Union

from ..agents.readonly_context import ReadonlyContext
from .base_tool import BaseTool

if TYPE_CHECKING:
from ..models.llm_request import LlmRequest
from .tool_context import ToolContext


@runtime_checkable
class ToolPredicate(Protocol):
Expand Down Expand Up @@ -96,3 +101,20 @@ def _is_tool_selected(
return tool.name in self.tool_filter

return False

async def process_llm_request(
    self, *, tool_context: ToolContext, llm_request: LlmRequest
) -> None:
  """Processes the outgoing LLM request for this toolset.

  This method will be called before each tool returned by this toolset
  processes the llm request. The default implementation is a no-op;
  subclasses may override it to mutate ``llm_request`` in place.

  Use cases:
  - Instead of letting each tool process the llm request individually, the
    toolset can process the llm request once for all of its tools. e.g.
    ComputerUseToolset can add the computer use tool to the llm request.

  Args:
    tool_context: The context of the tool.
    llm_request: The outgoing LLM request, mutable by this method.
  """
  pass
150 changes: 150 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for BaseLlmFlow toolset integration."""

from unittest.mock import AsyncMock

from google.adk.agents import Agent
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.tools.base_toolset import BaseToolset
from google.genai import types
import pytest

from ... import testing_utils


class BaseLlmFlowForTesting(BaseLlmFlow):
  """Concrete BaseLlmFlow subclass used as a fixture in these tests."""


@pytest.mark.asyncio
async def test_preprocess_calls_toolset_process_llm_request():
  """Test that _preprocess_async calls process_llm_request on toolsets."""

  class _SpyToolset(BaseToolset):
    """Toolset that records whether process_llm_request was invoked."""

    def __init__(self):
      super().__init__()
      self.process_llm_request_called = False
      self.process_llm_request = AsyncMock(side_effect=self._record_call)

    async def _record_call(self, **kwargs):
      self.process_llm_request_called = True

    async def get_tools(self, readonly_context=None):
      # No tools needed; the test only observes the toolset-level hook.
      return []

    async def close(self):
      pass

  spy_toolset = _SpyToolset()

  # Model that yields a single canned response so the flow can complete.
  canned_response = LlmResponse(
      content=types.Content(
          role='model', parts=[types.Part.from_text(text='Test response')]
      ),
      partial=False,
  )
  model = testing_utils.MockModel.create(responses=[canned_response])

  agent = Agent(name='test_agent', model=model, tools=[spy_toolset])
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )

  flow = BaseLlmFlowForTesting()
  llm_request = LlmRequest()

  # Drain the preprocessing event stream; the events themselves are not
  # relevant to this assertion.
  _ = [
      event
      async for event in flow._preprocess_async(
          invocation_context, llm_request
      )
  ]

  assert spy_toolset.process_llm_request_called


@pytest.mark.asyncio
async def test_preprocess_handles_mixed_tools_and_toolsets():
  """Test that _preprocess_async properly handles both tools and toolsets.

  The agent's tools list mixes a plain BaseTool, a bare function (wrapped by
  the framework), and a BaseToolset; preprocessing must invoke
  process_llm_request on both the tool and the toolset.
  """
  from google.adk.tools.base_tool import BaseTool

  class _MockTool(BaseTool):
    """Tool that records whether process_llm_request was invoked."""

    def __init__(self):
      super().__init__(name='mock_tool', description='Mock tool')
      self.process_llm_request_called = False
      self.process_llm_request = AsyncMock(side_effect=self._track_call)

    async def _track_call(self, **kwargs):
      self.process_llm_request_called = True

    async def call(self, **kwargs):
      return 'mock result'

  class _MockToolset(BaseToolset):
    """Toolset that records whether process_llm_request was invoked."""

    def __init__(self):
      super().__init__()
      self.process_llm_request_called = False
      self.process_llm_request = AsyncMock(side_effect=self._track_call)

    async def _track_call(self, **kwargs):
      self.process_llm_request_called = True

    async def get_tools(self, readonly_context=None):
      return []

    async def close(self):
      pass

  def _test_function():
    """Test function tool."""
    return 'function result'

  mock_tool = _MockTool()
  mock_toolset = _MockToolset()

  # Create agent with mixed tools and toolsets.
  agent = Agent(
      name='test_agent', tools=[mock_tool, _test_function, mock_toolset]
  )
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent, user_content='test message'
  )

  flow = BaseLlmFlowForTesting()
  llm_request = LlmRequest()

  # Drain the preprocessing event stream; only the side effects matter here.
  async for _ in flow._preprocess_async(invocation_context, llm_request):
    pass

  # Verify that process_llm_request was called on both the tool and toolset.
  assert mock_tool.process_llm_request_called
  assert mock_toolset.process_llm_request_called
109 changes: 109 additions & 0 deletions tests/unittests/tools/test_base_toolset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for BaseToolset."""

from typing import Optional

from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.llm_request import LlmRequest
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.base_toolset import BaseToolset
from google.adk.tools.tool_context import ToolContext
import pytest


class _TestingToolset(BaseToolset):
  """Minimal concrete BaseToolset used as a fixture by these tests."""

  async def get_tools(
      self, readonly_context: Optional[ReadonlyContext] = None
  ) -> list[BaseTool]:
    # No tools are required; the tests exercise process_llm_request only.
    tools: list[BaseTool] = []
    return tools

  async def close(self) -> None:
    # Nothing to release.
    pass


@pytest.mark.asyncio
async def test_process_llm_request_default_implementation():
  """Test that the default process_llm_request implementation does nothing."""
  toolset = _TestingToolset()

  # Build the minimal context objects the hook requires.
  session_service = InMemorySessionService()
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )
  invocation_context = InvocationContext(
      invocation_id='test_id',
      agent=SequentialAgent(name='test_agent'),
      session=session,
      session_service=session_service,
  )
  tool_context = ToolContext(invocation_context)

  llm_request = LlmRequest()
  # Snapshot the request up front so any mutation becomes observable.
  snapshot = LlmRequest.model_validate(llm_request.model_dump())

  await toolset.process_llm_request(
      tool_context=tool_context, llm_request=llm_request
  )

  # The base-class implementation must leave the request untouched.
  assert llm_request.model_dump() == snapshot.model_dump()


@pytest.mark.asyncio
async def test_process_llm_request_can_be_overridden():
  """Test that process_llm_request can be overridden by subclasses."""

  class _CustomToolset(_TestingToolset):
    """Toolset whose override mutates the outgoing request."""

    async def process_llm_request(
        self, *, tool_context: ToolContext, llm_request: LlmRequest
    ) -> None:
      # Add some custom processing so the assertion below can detect it.
      if not llm_request.contents:
        llm_request.contents = []
      llm_request.contents.append('Custom processing applied')

  toolset = _CustomToolset()

  # Build the minimal context objects the hook requires.
  session_service = InMemorySessionService()
  session = await session_service.create_session(
      app_name='test_app', user_id='test_user'
  )
  invocation_context = InvocationContext(
      invocation_id='test_id',
      agent=SequentialAgent(name='test_agent'),
      session=session,
      session_service=session_service,
  )
  tool_context = ToolContext(invocation_context)

  llm_request = LlmRequest()
  await toolset.process_llm_request(
      tool_context=tool_context, llm_request=llm_request
  )

  # The override's mutation must be visible on the request.
  assert llm_request.contents == ['Custom processing applied']