diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/mcp.py index e944b998..64b601ec 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/mcp.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/mcp.py @@ -57,7 +57,24 @@ def base_url(self) -> str: return self._mcp_base_url def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: - """Converts a raw MCP tool dictionary into the Toolbox ToolSchema.""" + """ + Safely converts the raw tool dictionary from the server into a ToolSchema object, + robustly handling optional authentication metadata. + """ + param_auth = None + invoke_auth = [] + + if "_meta" in tool_data and isinstance(tool_data["_meta"], dict): + meta = tool_data["_meta"] + if "toolbox/authParam" in meta and isinstance( + meta["toolbox/authParam"], dict + ): + param_auth = meta["toolbox/authParam"] + if "toolbox/authInvoke" in meta and isinstance( + meta["toolbox/authInvoke"], list + ): + invoke_auth = meta["toolbox/authInvoke"] + parameters = [] input_schema = tool_data.get("inputSchema", {}) properties = input_schema.get("properties", {}) @@ -71,6 +88,10 @@ def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: ) else: additional_props = True + if param_auth and name in param_auth: + auth_sources = param_auth[name] + else: + auth_sources = None parameters.append( ParameterSchema( name=name, @@ -78,11 +99,14 @@ def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: description=schema.get("description", ""), required=name in required, additionalProperties=additional_props, + authSources=auth_sources, ) ) return ToolSchema( - description=tool_data.get("description") or "", parameters=parameters + description=tool_data.get("description") or "", + parameters=parameters, + authRequired=invoke_auth, ) async def close(self): diff --git a/packages/toolbox-core/tests/mcp_transport/test_mcp.py b/packages/toolbox-core/tests/mcp_transport/test_mcp.py index dc85d0da..7d1baa14 100644 --- a/packages/toolbox-core/tests/mcp_transport/test_mcp.py +++ b/packages/toolbox-core/tests/mcp_transport/test_mcp.py @@ -162,6 +162,125 @@ def test_convert_tool_schema_complex_types(self, transport): assert p_obj.type == "object" assert p_obj.additionalProperties.type == "integer" + def test_convert_tool_schema_with_auth_params(self, transport): + raw_tool = { + "name": "auth_tool", + "description": "Tool with auth params", + "_meta": {"toolbox/authParam": {"api_key": ["header", "X-API-Key"]}}, + "inputSchema": { + "type": "object", + "properties": { + "api_key": {"type": "string"}, + "other_param": {"type": "string"}, + }, + }, + } + + schema = transport._convert_tool_schema(raw_tool) + api_key_param = next(p for p in schema.parameters if p.name == "api_key") + assert api_key_param.authSources == ["header", "X-API-Key"] + other_param = next(p for p in schema.parameters if p.name == "other_param") + assert other_param.authSources is None + + def test_convert_tool_schema_with_auth_invoke(self, transport): + raw_tool = { + "name": "invoke_auth_tool", + "description": "Tool requiring invocation auth", + "_meta": {"toolbox/authInvoke": ["Bearer", "OAuth2"]}, + "inputSchema": {"type": "object", "properties": {}}, + } + + schema = transport._convert_tool_schema(raw_tool) + + assert schema.authRequired == ["Bearer", "OAuth2"] + + def test_convert_tool_schema_multiple_auth_services(self, transport): + """ + Test where a single parameter requires multiple/complex auth definitions, + or multiple parameters have distinct auth requirements. + """ + raw_tool = { + "name": "multi_auth_tool", + "description": "Tool with multiple auth params", + "_meta": { + "toolbox/authParam": { + "service_a_key": ["header", "X-Service-A-Key"], + "service_b_token": ["header", "X-Service-B-Token"], + } + }, + "inputSchema": { + "type": "object", + "properties": { + "service_a_key": {"type": "string"}, + "service_b_token": {"type": "string"}, + "regular_param": {"type": "string"}, + }, + }, + } + + schema = transport._convert_tool_schema(raw_tool) + param_a = next(p for p in schema.parameters if p.name == "service_a_key") + assert param_a.authSources == ["header", "X-Service-A-Key"] + param_b = next(p for p in schema.parameters if p.name == "service_b_token") + assert param_b.authSources == ["header", "X-Service-B-Token"] + regular = next(p for p in schema.parameters if p.name == "regular_param") + assert regular.authSources is None + + def test_convert_tool_schema_mixed_auth_same_name(self, transport): + """ + Test both toolbox/authParam and toolbox/authInvoke present, + using the SAME auth definition (e.g., same Bearer token used for both). + """ + raw_tool = { + "name": "mixed_auth_same_tool", + "description": "Tool with overlapping auth requirements", + "_meta": { + "toolbox/authInvoke": ["Bearer", "SharedToken"], + "toolbox/authParam": {"auth_token": ["Bearer", "SharedToken"]}, + }, + "inputSchema": { + "type": "object", + "properties": { + "auth_token": {"type": "string"}, + "query": {"type": "string"}, + }, + }, + } + + schema = transport._convert_tool_schema(raw_tool) + assert schema.authRequired == ["Bearer", "SharedToken"] + param_auth = next(p for p in schema.parameters if p.name == "auth_token") + assert param_auth.authSources == ["Bearer", "SharedToken"] + + def test_convert_tool_schema_mixed_auth_different_names(self, transport): + """ + Test both toolbox/authParam and toolbox/authInvoke present, + but with DIFFERENT auth definitions (e.g. OAuth for the tool, API Key for a specific param). + """ + raw_tool = { + "name": "mixed_auth_diff_tool", + "description": "Tool with distinct auth requirements", + "_meta": { + "toolbox/authInvoke": ["Bearer", "GoogleOAuth"], + "toolbox/authParam": {"third_party_key": ["header", "X-3rd-Party-Key"]}, + }, + "inputSchema": { + "type": "object", + "properties": { + "third_party_key": {"type": "string"}, + "user_query": {"type": "string"}, + }, + }, + } + + schema = transport._convert_tool_schema(raw_tool) + assert schema.authRequired == ["Bearer", "GoogleOAuth"] + param_auth = next(p for p in schema.parameters if p.name == "third_party_key") + assert param_auth.authSources == ["header", "X-3rd-Party-Key"] + + param_normal = next(p for p in schema.parameters if p.name == "user_query") + assert param_normal.authSources is None + @pytest.mark.asyncio async def test_close_managed_session(self, mocker): mock_close = mocker.patch("aiohttp.ClientSession.close", new_callable=AsyncMock) diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py index d4c64e21..3a7e6102 100644 --- a/packages/toolbox-core/tests/test_e2e_mcp.py +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -133,6 +133,99 @@ async def test_bind_params_callable( assert "row4" not in response +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestAuth: + async def test_run_tool_unauth_with_auth( + self, toolbox: ToolboxClient, auth_token2: str + ): + """Tests running a tool that doesn't require auth, with auth provided.""" + + with pytest.raises( + ValueError, + match=rf"Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth", + ): + await toolbox.load_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) + + async def test_run_tool_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool requiring auth without providing auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool(id="2") + + async def test_run_tool_wrong_auth(self, toolbox: ToolboxClient, auth_token2: str): + """Tests running a tool with incorrect auth. The tool + requires a different authentication than the one provided.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token2}) + with pytest.raises( + Exception, + match="tool invocation not authorized. Please make sure your specify correct auth headers", + ): + await auth_tool(id="2") + + async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token1}) + response = await auth_tool(id="2") + assert "row2" in response + + @pytest.mark.asyncio + async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth using an async token getter.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + + async def get_token_asynchronously(): + return auth_token1 + + auth_tool = tool.add_auth_token_getters( + {"my-test-auth": get_token_asynchronously} + ) + response = await auth_tool(id="2") + assert "row2" in response + + async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool with a param requiring auth, without auth.""" + tool = await toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool() + + async def test_run_tool_param_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + response = await tool() + assert "row4" in response + assert "row5" in response + assert "row6" in response + + async def test_run_tool_param_auth_no_field( + self, toolbox: ToolboxClient, auth_token1: str + ): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = await toolbox.load_tool( + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + with pytest.raises( + Exception, + match="no field named row_data in claims", + ): + await tool() + + @pytest.mark.asyncio @pytest.mark.usefixtures("toolbox_server") class TestOptionalParams: