Skip to content

Commit 7deffb1

Browse files
wukathcopybara-github
authored andcommitted
fix: pass tool context into require_confirmation function in McpTool
Aligns with the FunctionTool implementation of require_confirmation. This fixes #4327. Co-authored-by: Kathy Wu <[email protected]> PiperOrigin-RevId: 865566362
1 parent ac1401b commit 7deffb1

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,31 @@ async def run_async(
151151
self, *, args: dict[str, Any], tool_context: ToolContext
152152
) -> Any:
153153
if isinstance(self._require_confirmation, Callable):
154+
args_to_call = args.copy()
155+
try:
156+
signature = inspect.signature(self._require_confirmation)
157+
valid_params = set(signature.parameters.keys())
158+
has_kwargs = any(
159+
param.kind == inspect.Parameter.VAR_KEYWORD
160+
for param in signature.parameters.values()
161+
)
162+
163+
if "tool_context" in valid_params or has_kwargs:
164+
args_to_call["tool_context"] = tool_context
165+
166+
# Filter args_to_call only if there's no **kwargs
167+
if not has_kwargs:
168+
# Add tool_context to valid_params if it was added to args_to_call
169+
if "tool_context" in args_to_call:
170+
valid_params.add("tool_context")
171+
args_to_call = {
172+
k: v for k, v in args_to_call.items() if k in valid_params
173+
}
174+
except ValueError:
175+
args_to_call = args
176+
154177
require_confirmation = await self._invoke_callable(
155-
self._require_confirmation, args
178+
self._require_confirmation, args_to_call
156179
)
157180
else:
158181
require_confirmation = bool(self._require_confirmation)

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
from unittest.mock import AsyncMock
1617
from unittest.mock import Mock
1718
from unittest.mock import patch
@@ -669,6 +670,50 @@ async def test_run_async_require_confirmation_true_confirmed(self):
669670
args=args, tool_context=tool_context
670671
)
671672

673+
@pytest.mark.asyncio
674+
async def test_run_async_require_confirmation_callable_with_arg_filtering(
675+
self,
676+
):
677+
"""Test require_confirmation=callable with argument filtering."""
678+
679+
async def _require_confirmation_func(
680+
param1: str, tool_context: ToolContext
681+
):
682+
return True
683+
684+
tool = MCPTool(
685+
mcp_tool=self.mock_mcp_tool,
686+
mcp_session_manager=self.mock_session_manager,
687+
require_confirmation=_require_confirmation_func,
688+
)
689+
tool_context = Mock(spec=ToolContext)
690+
tool_context.tool_confirmation = None
691+
tool_context.request_confirmation = Mock()
692+
args = {"param1": "test_value", "extra_arg": 123}
693+
694+
with patch.object(
695+
tool, "_invoke_callable", new_callable=AsyncMock
696+
) as mock_invoke_callable:
697+
mock_invoke_callable.return_value = (
698+
True # Mock the return of require_confirmation
699+
)
700+
701+
result = await tool.run_async(args=args, tool_context=tool_context)
702+
expected_args_to_call = {
703+
"param1": "test_value",
704+
"tool_context": tool_context,
705+
}
706+
mock_invoke_callable.assert_called_once_with(
707+
_require_confirmation_func, expected_args_to_call
708+
)
709+
710+
assert result == {
711+
"error": (
712+
"This tool call requires confirmation, please approve or reject."
713+
)
714+
}
715+
tool_context.request_confirmation.assert_called_once()
716+
672717
@pytest.mark.asyncio
673718
async def test_run_async_require_confirmation_callable_true_no_confirmation(
674719
self,

0 commit comments

Comments
 (0)