Skip to content

Commit c7eebbb

Browse files
authored
feat: Implement strict validation for client while loading tools (#205)
* fix: Add the no parameter check back again. We will remove this once we actually implement the `strict` flag and centralize this functionality by moving this check to the tool's constructor in a future PR. * fix: Reverse the error conditions to avoid masking of the second error. * feat: Implement strict validation for client while loading tools * fix: Fix adding tool's name for strict while loading toolset. * chore: Add unit test cases. * chore: Delint * chore: Fix integration tests. * chore: Delint * chore: Remove unnecessary if statement
1 parent d0b5020 commit c7eebbb

File tree

4 files changed

+652
-28
lines changed

4 files changed

+652
-28
lines changed

packages/toolbox-core/src/toolbox_core/client.py

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ async def load_tool(
157157
for execution. The specific arguments and behavior of the callable
158158
depend on the tool itself.
159159
160+
Raises:
161+
ValueError: If the loaded tool instance fails to utilize at least
162+
one provided parameter or auth token (if any provided).
160163
"""
161164
# Resolve client headers
162165
resolved_headers = {
@@ -174,55 +177,130 @@ async def load_tool(
174177
if name not in manifest.tools:
175178
# TODO: Better exception
176179
raise Exception(f"Tool '{name}' not found!")
177-
tool, _, _ = self.__parse_tool(
180+
tool, used_auth_keys, used_bound_keys = self.__parse_tool(
178181
name,
179182
manifest.tools[name],
180183
auth_token_getters,
181184
bound_params,
182185
self.__client_headers,
183186
)
184187

188+
provided_auth_keys = set(auth_token_getters.keys())
189+
provided_bound_keys = set(bound_params.keys())
190+
191+
unused_auth = provided_auth_keys - used_auth_keys
192+
unused_bound = provided_bound_keys - used_bound_keys
193+
194+
if unused_auth or unused_bound:
195+
error_messages = []
196+
if unused_auth:
197+
error_messages.append(f"unused auth tokens: {', '.join(unused_auth)}")
198+
if unused_bound:
199+
error_messages.append(
200+
f"unused bound parameters: {', '.join(unused_bound)}"
201+
)
202+
raise ValueError(
203+
f"Validation failed for tool '{name}': { '; '.join(error_messages) }."
204+
)
205+
185206
return tool
186207

187208
async def load_toolset(
188209
self,
189210
name: Optional[str] = None,
190211
auth_token_getters: dict[str, Callable[[], str]] = {},
191212
bound_params: Mapping[str, Union[Callable[[], Any], Any]] = {},
213+
strict: bool = False,
192214
) -> list[ToolboxTool]:
193215
"""
194216
Asynchronously fetches a toolset and loads all tools defined within it.
195217
196218
Args:
197-
name: Name of the toolset to load tools.
219+
name: Name of the toolset to load. If None, loads the default toolset.
198220
auth_token_getters: A mapping of authentication service names to
199221
callables that return the corresponding authentication token.
200222
bound_params: A mapping of parameter names to bind to specific values or
201223
callables that are called to produce values as needed.
224+
strict: If True, raises an error if *any* loaded tool instance fails
225+
to utilize at least one provided parameter or auth token (if any
226+
provided). If False (default), raises an error only if a
227+
user-provided parameter or auth token cannot be applied to *any*
228+
loaded tool across the set.
202229
203230
Returns:
204231
list[ToolboxTool]: A list of callables, one for each tool defined
205232
in the toolset.
233+
234+
Raises:
235+
ValueError: If validation fails based on the `strict` flag.
206236
"""
237+
207238
# Resolve client headers
208239
original_headers = self.__client_headers
209240
resolved_headers = {
210241
header_name: await resolve_value(original_headers[header_name])
211242
for header_name in original_headers
212243
}
213-
# Request the definition of the tool from the server
244+
# Request the definition of the toolset from the server
214245
url = f"{self.__base_url}/api/toolset/{name or ''}"
215246
async with self.__session.get(url, headers=resolved_headers) as response:
216247
json = await response.json()
217248
manifest: ManifestSchema = ManifestSchema(**json)
218249

219-
# parse each tools name and schema into a list of ToolboxTools
220-
tools = [
221-
self.__parse_tool(
222-
n, s, auth_token_getters, bound_params, self.__client_headers
223-
)[0]
224-
for n, s in manifest.tools.items()
225-
]
250+
tools: list[ToolboxTool] = []
251+
overall_used_auth_keys: set[str] = set()
252+
overall_used_bound_params: set[str] = set()
253+
provided_auth_keys = set(auth_token_getters.keys())
254+
provided_bound_keys = set(bound_params.keys())
255+
256+
# parse each tool's name and schema into a list of ToolboxTools
257+
for tool_name, schema in manifest.tools.items():
258+
tool, used_auth_keys, used_bound_keys = self.__parse_tool(
259+
tool_name,
260+
schema,
261+
auth_token_getters,
262+
bound_params,
263+
self.__client_headers,
264+
)
265+
tools.append(tool)
266+
267+
if strict:
268+
unused_auth = provided_auth_keys - used_auth_keys
269+
unused_bound = provided_bound_keys - used_bound_keys
270+
if unused_auth or unused_bound:
271+
error_messages = []
272+
if unused_auth:
273+
error_messages.append(
274+
f"unused auth tokens: {', '.join(unused_auth)}"
275+
)
276+
if unused_bound:
277+
error_messages.append(
278+
f"unused bound parameters: {', '.join(unused_bound)}"
279+
)
280+
raise ValueError(
281+
f"Validation failed for tool '{tool_name}': { '; '.join(error_messages) }."
282+
)
283+
else:
284+
overall_used_auth_keys.update(used_auth_keys)
285+
overall_used_bound_params.update(used_bound_keys)
286+
287+
unused_auth = provided_auth_keys - overall_used_auth_keys
288+
unused_bound = provided_bound_keys - overall_used_bound_params
289+
290+
if unused_auth or unused_bound:
291+
error_messages = []
292+
if unused_auth:
293+
error_messages.append(
294+
f"unused auth tokens could not be applied to any tool: {', '.join(unused_auth)}"
295+
)
296+
if unused_bound:
297+
error_messages.append(
298+
f"unused bound parameters could not be applied to any tool: {', '.join(unused_bound)}"
299+
)
300+
raise ValueError(
301+
f"Validation failed for toolset '{name or 'default'}': { '; '.join(error_messages) }."
302+
)
303+
226304
return tools
227305

228306
async def add_headers(

0 commit comments

Comments
 (0)