Skip to content

Commit 9725c44

Browse files
committed
feat: Implement strict validation for client while loading tools
1 parent e5763b9 commit 9725c44

File tree

1 file changed

+91
-10
lines changed
  • packages/toolbox-core/src/toolbox_core

1 file changed

+91
-10
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ async def load_tool(
157157
for execution. The specific arguments and behavior of the callable
158158
depend on the tool itself.
159159
160+
Raises:
161+
ValueError: If the loaded tool instance fails to utilize at least
162+
one provided parameter or auth token (if any provided).
160163
"""
161164
# Resolve client headers
162165
resolved_headers = {
@@ -174,55 +177,133 @@ async def load_tool(
174177
if name not in manifest.tools:
175178
# TODO: Better exception
176179
raise Exception(f"Tool '{name}' not found!")
177-
tool, _, _ = self.__parse_tool(
180+
tool, used_auth_keys, used_bound_keys = self.__parse_tool(
178181
name,
179182
manifest.tools[name],
180183
auth_token_getters,
181184
bound_params,
182185
self.__client_headers,
183186
)
184187

188+
provided_auth_keys = set(auth_token_getters.keys())
189+
provided_bound_keys = set(bound_params.keys())
190+
191+
unused_auth = provided_auth_keys - used_auth_keys
192+
unused_bound = provided_bound_keys - used_bound_keys
193+
194+
if unused_auth or unused_bound:
195+
error_messages = []
196+
if unused_auth:
197+
error_messages.append(
198+
f"unused auth tokens: {', '.join(unused_auth)}"
199+
)
200+
if unused_bound:
201+
error_messages.append(
202+
f"unused bound parameters: {', '.join(unused_bound)}"
203+
)
204+
raise ValueError(
205+
f"Validation failed for tool '{name}': { '; '.join(error_messages) }."
206+
)
207+
185208
return tool
186209

187210
async def load_toolset(
188211
self,
189212
name: Optional[str] = None,
190213
auth_token_getters: dict[str, Callable[[], str]] = {},
191214
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
215+
strict: bool = False,
192216
) -> list[ToolboxTool]:
193217
"""
194218
Asynchronously fetches a toolset and loads all tools defined within it.
195219
196220
Args:
197-
name: Name of the toolset to load tools.
221+
name: Name of the toolset to load. If None, loads the default toolset.
198222
auth_token_getters: A mapping of authentication service names to
199223
callables that return the corresponding authentication token.
200224
bound_params: A mapping of parameter names to bind to specific values or
201225
callables that are called to produce values as needed.
226+
strict: If True, raises an error if *any* loaded tool instance fails
227+
to utilize at least one provided parameter or auth token (if any
228+
provided). If False (default), raises an error only if a
229+
user-provided parameter or auth token cannot be applied to *any*
230+
loaded tool across the set.
202231
203232
Returns:
204233
list[ToolboxTool]: A list of callables, one for each tool defined
205234
in the toolset.
235+
236+
Raises:
237+
ValueError: If validation fails based on the `strict` flag.
206238
"""
239+
207240
# Resolve client headers
208241
original_headers = self.__client_headers
209242
resolved_headers = {
210243
header_name: await resolve_value(original_headers[header_name])
211244
for header_name in original_headers
212245
}
213-
# Request the definition of the tool from the server
246+
# Request the definition of the toolset from the server
214247
url = f"{self.__base_url}/api/toolset/{name or ''}"
215248
async with self.__session.get(url, headers=resolved_headers) as response:
216249
json = await response.json()
217250
manifest: ManifestSchema = ManifestSchema(**json)
218251

219-
# parse each tools name and schema into a list of ToolboxTools
220-
tools = [
221-
self.__parse_tool(
222-
n, s, auth_token_getters, bound_params, self.__client_headers
223-
)[0]
224-
for n, s in manifest.tools.items()
225-
]
252+
tools: list[ToolboxTool] = []
253+
overall_used_auth_keys: set[str] = set()
254+
overall_used_bound_params: set[str] = set()
255+
provided_auth_keys = set(auth_token_getters.keys())
256+
provided_bound_keys = set(bound_params.keys())
257+
258+
# parse each tool's name and schema into a list of ToolboxTools
259+
for tool_name, schema in manifest.tools.items():
260+
tool, used_auth_keys, used_bound_keys = self.__parse_tool(
261+
tool_name,
262+
schema,
263+
auth_token_getters,
264+
bound_params,
265+
self.__client_headers,
266+
)
267+
tools.append(tool)
268+
269+
if strict:
270+
unused_auth = provided_auth_keys - used_auth_keys
271+
unused_bound = provided_bound_keys - used_bound_keys
272+
if unused_auth or unused_bound:
273+
error_messages = []
274+
if unused_auth:
275+
error_messages.append(
276+
f"unused auth tokens: {', '.join(unused_auth)}"
277+
)
278+
if unused_bound:
279+
error_messages.append(
280+
f"unused bound parameters: {', '.join(unused_bound)}"
281+
)
282+
raise ValueError(
283+
f"Validation failed for tool '{name}': { '; '.join(error_messages) }."
284+
)
285+
else:
286+
overall_used_auth_keys.update(used_auth_keys)
287+
overall_used_bound_params.update(used_bound_keys)
288+
289+
if not strict:
290+
unused_auth = provided_auth_keys - overall_used_auth_keys
291+
unused_bound = provided_bound_keys - overall_used_bound_params
292+
293+
if unused_auth or unused_bound:
294+
error_messages = []
295+
if unused_auth:
296+
error_messages.append(
297+
f"unused auth tokens could not be applied to any tool: {', '.join(unused_auth)}"
298+
)
299+
if unused_bound:
300+
error_messages.append(
301+
f"unused bound parameters could not be applied to any tool: {', '.join(unused_bound)}"
302+
)
303+
raise ValueError(
304+
f"Validation failed for toolset '{name or 'default'}': { '; '.join(error_messages) }."
305+
)
306+
226307
return tools
227308

228309
async def add_headers(

0 commit comments

Comments
 (0)