1111import warnings
1212from collections .abc import ItemsView , KeysView , ValuesView
1313from copy import deepcopy
14- from typing import Callable , List , Optional
14+ from typing import Callable , List , Optional , get_args
1515
1616from pydantic import BaseModel
1717from rich .pretty import pprint as rich_print
2727 ToolCall ,
2828 ToolOutput ,
2929 ToolParameter ,
30+ ToolParameterType ,
3031)
3132from invariant .analyzer .runtime .runtime_errors import InvariantInputValidationError
3233
@@ -381,6 +382,12 @@ def parse_input(self, input: list[dict]) -> list[Event]:
381382 if not isinstance (event , dict ):
382383 parsed_data .append (event )
383384 continue
385+
386+ # Extract relevant metadata from the event
387+ if "metadata" in event :
388+ if "server" in event ["metadata" ]:
389+ event ["server" ] = event ["metadata" ]["server" ]
390+
384391 if "role" in event :
385392 if event ["role" ] != "tool" :
386393 # if arguments are given as string convert them into dict using json.loads(...)
@@ -394,7 +401,6 @@ def parse_input(self, input: list[dict]) -> list[Event]:
394401 # # convert .content str to [{"type": "text": <content>}]
395402 # if type(event.get("content")) is str:
396403 # event["content"] = [{"type": "text", "text": event["content"]}]
397-
398404 msg = Message (** event )
399405 parsed_data .append (msg )
400406 if msg .tool_calls is not None :
@@ -420,22 +426,24 @@ def parse_tool_param(
420426 name : str , schema : dict , required_keys : Optional [List [str ]] = None
421427 ) -> ToolParameter :
422428 param_type = schema .get ("type" , "string" )
423- description = schema .get ("description" , "" )
429+ description = schema .get ("description" )
430+ if description is None :
431+ description = "no description available"
424432
425433 # Only object-level schemas have required fields as a list
426434 if required_keys is None :
427435 required_keys = schema .get ("required" , [])
428436
429437 aliases = {
430- "integer " : "number " ,
431- "int " : "number " ,
438+ "int " : "integer " ,
439+ "long " : "integer " ,
432440 "float" : "number" ,
433441 "bool" : "boolean" ,
434442 "str" : "string" ,
435443 "dict" : "object" ,
436444 "list" : "array" ,
437445 }
438- if param_type in aliases :
446+ if isinstance ( param_type , str ) and param_type in aliases :
439447 param_type = aliases [param_type ]
440448
441449 if param_type == "object" :
@@ -446,30 +454,56 @@ def parse_tool_param(
446454 schema = subschema ,
447455 required_keys = schema .get ("required" , []),
448456 )
457+ additional_properties = (
458+ bool (schema .get ("additionalProperties" ))
459+ if schema .get ("additionalProperties" ) is not None
460+ else None
461+ )
462+
449463 return ToolParameter (
450464 name = name ,
451465 type = "object" ,
452466 description = description ,
453467 required = name in required_keys ,
454468 properties = properties ,
455- additionalProperties = schema . get ( "additionalProperties" ) ,
469+ additionalProperties = additional_properties ,
456470 )
457471 elif param_type == "array" :
458472 return ToolParameter (
459473 name = name ,
460474 type = "array" ,
461475 description = description ,
462476 required = name in required_keys ,
463- items = parse_tool_param (name = f"{ name } item" , schema = schema ["items" ]),
477+ items = parse_tool_param (
478+ name = f"{ name } item" , schema = schema ["items" ], required_keys = []
479+ )
480+ if "items" in schema
481+ else None ,
464482 )
465- elif param_type in ["object " , "array " , "string" , "number " , "boolean" ]:
483+ elif param_type in ["string " , "number " , "integer " , "boolean" ]:
466484 return ToolParameter (
467485 name = name ,
468486 type = param_type ,
469487 description = description ,
470488 required = name in required_keys ,
471489 enum = schema .get ("enum" ),
472490 )
491+ elif isinstance (param_type , list ):
492+ required = name in required_keys
493+ for param in param_type :
494+ if "null" in param :
495+ required = False
496+ continue
497+ if param not in get_args (ToolParameterType ):
498+ raise InvariantInputValidationError (
499+ f"Unsupported schema type: { param } for parameter { name } . Supported types are: object, array, string, number, boolean."
500+ )
501+ return ToolParameter (
502+ name = name ,
503+ type = [p for p in param_type if p != "null" ],
504+ description = description ,
505+ required = required ,
506+ )
473507 else :
474508 raise InvariantInputValidationError (
475509 f"Unsupported schema type: { param_type } for parameter { name } . Supported types are: object, array, string, number, boolean."
@@ -488,10 +522,22 @@ def parse_tool_param(
488522 )
489523 )
490524
525+ tool_desc = tool .get ("description" )
526+ if tool_desc is None :
527+ tool_desc = "no description available"
528+ if not isinstance (tool_desc , str ):
529+ try :
530+ tool_desc = str (tool_desc )
531+ except Exception :
532+ raise InvariantInputValidationError (
533+ f"Tool description should be a string. Instead, got: { tool_desc } of type { type (tool_desc )} "
534+ )
535+ server = tool ["metadata" ].get ("server" , None ) if "metadata" in tool else None
491536 tool_obj = Tool (
492537 name = name ,
493- description = tool [ "description" ] ,
538+ description = tool_desc ,
494539 inputSchema = properties ,
540+ server = server ,
495541 )
496542 parsed_data .append (tool_obj )
497543 else :
0 commit comments