Skip to content

Commit 7793a27

Browse files
authored
refactor: move transport logic to a ToolboxTransport class (#344)
* add basic code * fixes * test fix * new unit tests * rename ToolboxTransport * add py3.9 support * fix langchain tool tests * test fix * lint * fix tests * move manage session into transport * move warning to diff file * avoid code duplication * fix tests * lint * remove redundant tests * make invoke method return str * lint * fix return type * small refactor * rename private method * fix tests * lint
1 parent 023a7eb commit 7793a27

File tree

8 files changed

+472
-483
lines changed

8 files changed

+472
-483
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
from aiohttp import ClientSession
2020
from deprecated import deprecated
2121

22-
from .protocol import ManifestSchema, ToolSchema
22+
from .itransport import ITransport
23+
from .protocol import ToolSchema
2324
from .tool import ToolboxTool
25+
from .toolbox_transport import ToolboxTransport
2426
from .utils import identify_auth_requirements, resolve_value
2527

2628

@@ -33,9 +35,7 @@ class ToolboxClient:
3335
is not provided.
3436
"""
3537

36-
__base_url: str
37-
__session: ClientSession
38-
__manage_session: bool
38+
__transport: ITransport
3939

4040
def __init__(
4141
self,
@@ -56,15 +56,8 @@ def __init__(
5656
should typically be managed externally.
5757
client_headers: Headers to include in each request sent through this client.
5858
"""
59-
self.__base_url = url
60-
61-
# If no aiohttp.ClientSession is provided, make our own
62-
self.__manage_session = False
63-
if session is None:
64-
self.__manage_session = True
65-
session = ClientSession()
66-
self.__session = session
6759

60+
self.__transport = ToolboxTransport(url, session)
6861
self.__client_headers = client_headers if client_headers is not None else {}
6962

7063
def __parse_tool(
@@ -103,8 +96,7 @@ def __parse_tool(
10396
)
10497

10598
tool = ToolboxTool(
106-
session=self.__session,
107-
base_url=self.__base_url,
99+
transport=self.__transport,
108100
name=name,
109101
description=schema.description,
110102
# create a read-only values to prevent mutation
@@ -149,8 +141,7 @@ async def close(self):
149141
If the session was provided externally during initialization, the caller
150142
is responsible for its lifecycle.
151143
"""
152-
if self.__manage_session and not self.__session.closed:
153-
await self.__session.close()
144+
await self.__transport.close()
154145

155146
async def load_tool(
156147
self,
@@ -191,16 +182,7 @@ async def load_tool(
191182
for name, val in self.__client_headers.items()
192183
}
193184

194-
# request the definition of the tool from the server
195-
url = f"{self.__base_url}/api/tool/{name}"
196-
async with self.__session.get(url, headers=resolved_headers) as response:
197-
if not response.ok:
198-
error_text = await response.text()
199-
raise RuntimeError(
200-
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
201-
)
202-
json = await response.json()
203-
manifest: ManifestSchema = ManifestSchema(**json)
185+
manifest = await self.__transport.tool_get(name, resolved_headers)
204186

205187
# parse the provided definition to a tool
206188
if name not in manifest.tools:
@@ -274,16 +256,8 @@ async def load_toolset(
274256
header_name: await resolve_value(original_headers[header_name])
275257
for header_name in original_headers
276258
}
277-
# Request the definition of the toolset from the server
278-
url = f"{self.__base_url}/api/toolset/{name or ''}"
279-
async with self.__session.get(url, headers=resolved_headers) as response:
280-
if not response.ok:
281-
error_text = await response.text()
282-
raise RuntimeError(
283-
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
284-
)
285-
json = await response.json()
286-
manifest: ManifestSchema = ManifestSchema(**json)
259+
260+
manifest = await self.__transport.tools_list(name, resolved_headers)
287261

288262
tools: list[ToolboxTool] = []
289263
overall_used_auth_keys: set[str] = set()
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
from typing import Mapping, Optional
17+
18+
from .protocol import ManifestSchema
19+
20+
21+
class ITransport(ABC):
22+
"""Defines the contract for a 'smart' transport that handles both
23+
protocol formatting and network communication.
24+
"""
25+
26+
@property
27+
@abstractmethod
28+
def base_url(self) -> str:
29+
"""The base URL for the transport."""
30+
pass
31+
32+
@abstractmethod
33+
async def tool_get(
34+
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
35+
) -> ManifestSchema:
36+
"""Gets a single tool from the server."""
37+
pass
38+
39+
@abstractmethod
40+
async def tools_list(
41+
self,
42+
toolset_name: Optional[str] = None,
43+
headers: Optional[Mapping[str, str]] = None,
44+
) -> ManifestSchema:
45+
"""Lists available tools from the server."""
46+
pass
47+
48+
@abstractmethod
49+
async def tool_invoke(
50+
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
51+
) -> str:
52+
"""Invokes a specific tool on the server."""
53+
pass
54+
55+
@abstractmethod
56+
async def close(self):
57+
"""Closes any underlying connections."""
58+
pass

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
2121
from warnings import warn
2222

23-
from aiohttp import ClientSession
24-
23+
from .itransport import ITransport
2524
from .protocol import ParameterSchema
2625
from .utils import (
2726
create_func_docstring,
@@ -46,8 +45,7 @@ class ToolboxTool:
4645

4746
def __init__(
4847
self,
49-
session: ClientSession,
50-
base_url: str,
48+
transport: ITransport,
5149
name: str,
5250
description: str,
5351
params: Sequence[ParameterSchema],
@@ -68,8 +66,7 @@ def __init__(
6866
Toolbox server.
6967
7068
Args:
71-
session: The `aiohttp.ClientSession` used for making API requests.
72-
base_url: The base URL of the Toolbox server API.
69+
transport: The transport used for making API requests.
7370
name: The name of the remote tool.
7471
description: The description of the remote tool.
7572
params: The args of the tool.
@@ -84,9 +81,7 @@ def __init__(
8481
client_headers: Client specific headers bound to the tool.
8582
"""
8683
# used to invoke the toolbox API
87-
self.__session: ClientSession = session
88-
self.__base_url: str = base_url
89-
self.__url = f"{base_url}/api/tool/{name}/invoke"
84+
self.__transport = transport
9085
self.__description = description
9186
self.__params = params
9287
self.__pydantic_model = params_to_pydantic_model(name, self.__params)
@@ -120,17 +115,6 @@ def __init__(
120115
# map of client headers to their value/callable/coroutine
121116
self.__client_headers = client_headers
122117

123-
# ID tokens contain sensitive user information (claims). Transmitting
124-
# these over HTTP exposes the data to interception and unauthorized
125-
# access. Always use HTTPS to ensure secure communication and protect
126-
# user privacy.
127-
if (
128-
required_authn_params or required_authz_tokens or client_headers
129-
) and not self.__url.startswith("https://"):
130-
warn(
131-
"Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication."
132-
)
133-
134118
@property
135119
def _name(self) -> str:
136120
return self.__name__
@@ -171,8 +155,7 @@ def _client_headers(
171155

172156
def __copy(
173157
self,
174-
session: Optional[ClientSession] = None,
175-
base_url: Optional[str] = None,
158+
transport: Optional[ITransport] = None,
176159
name: Optional[str] = None,
177160
description: Optional[str] = None,
178161
params: Optional[Sequence[ParameterSchema]] = None,
@@ -192,8 +175,7 @@ def __copy(
192175
Creates a copy of the ToolboxTool, overriding specific fields.
193176
194177
Args:
195-
session: The `aiohttp.ClientSession` used for making API requests.
196-
base_url: The base URL of the Toolbox server API.
178+
transport: The transport used for making API requests.
197179
name: The name of the remote tool.
198180
description: The description of the remote tool.
199181
params: The args of the tool.
@@ -209,8 +191,7 @@ def __copy(
209191
"""
210192
check = lambda val, default: val if val is not None else default
211193
return ToolboxTool(
212-
session=check(session, self.__session),
213-
base_url=check(base_url, self.__base_url),
194+
transport=check(transport, self.__transport),
214195
name=check(name, self.__name__),
215196
description=check(description, self.__description),
216197
params=check(params, self.__params),
@@ -291,16 +272,11 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
291272
token_getter
292273
)
293274

294-
async with self.__session.post(
295-
self.__url,
296-
json=payload,
297-
headers=headers,
298-
) as resp:
299-
body = await resp.json()
300-
if not resp.ok:
301-
err = body.get("error", f"unexpected status from server: {resp.status}")
302-
raise Exception(err)
303-
return body.get("result", body)
275+
return await self.__transport.tool_invoke(
276+
self.__name__,
277+
payload,
278+
headers,
279+
)
304280

305281
def add_auth_token_getters(
306282
self,
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Mapping, Optional
16+
from warnings import warn
17+
18+
from aiohttp import ClientSession
19+
20+
from .itransport import ITransport
21+
from .protocol import ManifestSchema
22+
23+
24+
class ToolboxTransport(ITransport):
25+
"""Transport for the native Toolbox protocol."""
26+
27+
def __init__(self, base_url: str, session: Optional[ClientSession]):
28+
self.__base_url = base_url
29+
30+
# If no aiohttp.ClientSession is provided, make our own
31+
self.__manage_session = False
32+
if session is not None:
33+
self.__session = session
34+
else:
35+
self.__manage_session = True
36+
self.__session = ClientSession()
37+
38+
@property
39+
def base_url(self) -> str:
40+
"""The base URL for the transport."""
41+
return self.__base_url
42+
43+
async def __get_manifest(
44+
self, url: str, headers: Optional[Mapping[str, str]]
45+
) -> ManifestSchema:
46+
"""Helper method to perform GET requests and parse the ManifestSchema."""
47+
async with self.__session.get(url, headers=headers) as response:
48+
if not response.ok:
49+
error_text = await response.text()
50+
raise RuntimeError(
51+
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
52+
)
53+
json = await response.json()
54+
return ManifestSchema(**json)
55+
56+
async def tool_get(
57+
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
58+
) -> ManifestSchema:
59+
url = f"{self.__base_url}/api/tool/{tool_name}"
60+
return await self.__get_manifest(url, headers)
61+
62+
async def tools_list(
63+
self,
64+
toolset_name: Optional[str] = None,
65+
headers: Optional[Mapping[str, str]] = None,
66+
) -> ManifestSchema:
67+
url = f"{self.__base_url}/api/toolset/{toolset_name or ''}"
68+
return await self.__get_manifest(url, headers)
69+
70+
async def tool_invoke(
71+
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
72+
) -> str:
73+
# ID tokens contain sensitive user information (claims). Transmitting
74+
# these over HTTP exposes the data to interception and unauthorized
75+
# access. Always use HTTPS to ensure secure communication and protect
76+
# user privacy.
77+
if self.base_url.startswith("http://") and headers:
78+
warn(
79+
"Sending data token over HTTP. User data may be exposed. Use HTTPS for secure communication."
80+
)
81+
url = f"{self.__base_url}/api/tool/{tool_name}/invoke"
82+
async with self.__session.post(
83+
url,
84+
json=arguments,
85+
headers=headers,
86+
) as resp:
87+
body = await resp.json()
88+
if not resp.ok:
89+
err = body.get("error", f"unexpected status from server: {resp.status}")
90+
raise Exception(err)
91+
return body.get("result")
92+
93+
async def close(self):
94+
if self.__manage_session and not self.__session.closed:
95+
await self.__session.close()

0 commit comments

Comments
 (0)