diff --git a/invariant/analyzer/language/scope.py b/invariant/analyzer/language/scope.py
index 5f10453..f1584e7 100644
--- a/invariant/analyzer/language/scope.py
+++ b/invariant/analyzer/language/scope.py
@@ -30,6 +30,7 @@
     "sum",
     "print",
     "tuple",
+    "tool_call",
     "text",
     "image",
 ]
diff --git a/invariant/analyzer/stdlib/invariant/builtins.py b/invariant/analyzer/stdlib/invariant/builtins.py
index 7012138..141150c 100644
--- a/invariant/analyzer/stdlib/invariant/builtins.py
+++ b/invariant/analyzer/stdlib/invariant/builtins.py
@@ -91,3 +91,23 @@ def tuple(*args, **kwargs):
     Creates a tuple from the given arguments.
     """
     return py_builtins.tuple(*args, **kwargs)
+
+
+def tool_call(tool_output: ToolOutput, *args, **kwargs) -> ToolCall:
+    """
+    Gets the tool call object from a tool output.
+
+    Args:
+        tool_output: A ToolOutput object.
+        *args, **kwargs: Accepted but not used by this function.
+
+    Returns:
+        The ToolCall object that corresponds to the tool call in the tool output.
+
+    Raises:
+        ValueError: If tool_output is not a ToolOutput instance.
+    """
+    if not isinstance(tool_output, ToolOutput):
+        raise ValueError("tool_output argument must be a ToolOutput.")
+
+    return tool_output._tool_call
diff --git a/invariant/tests/analyzer/test_builtins.py b/invariant/tests/analyzer/test_builtins.py
new file mode 100644
index 0000000..a334d65
--- /dev/null
+++ b/invariant/tests/analyzer/test_builtins.py
@@ -0,0 +1,36 @@
+import unittest
+
+from invariant.analyzer import Policy
+from invariant.analyzer.traces import assistant, tool, tool_call, user
+
+
+class TestBuiltins(unittest.TestCase):
+    def test_tool_call_name(self):
+        policy = Policy.from_string("""
+        raise "error" if:
+            (tool_output: ToolOutput)
+            tool_call(tool_output).function.name == "some_tool"
+        """)
+
+        trace = [
+            user("What is the result of something?"),
+            assistant(None, tool_call("1", "some_tool", {})),
+            tool("1", "some_output"),
+        ]
+        result = policy.analyze(trace)
+        assert len(result.errors) == 1
+
+    def test_tool_call_name_no_match(self):
+        policy = Policy.from_string("""
+        raise "error" if:
+            (tool_output: ToolOutput)
+            tool_call(tool_output).function.name == "some_tool"
+        """)
+
+        trace = [
+            user("What is the result of something?"),
+            assistant(None, tool_call("1", "some_other_tool", {})),
+            tool("1", "some_output"),
+        ]
+        result = policy.analyze(trace)
+        assert len(result.errors) == 0