Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions invariant/analyzer/runtime/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,15 @@ def parse_tool_param(

for trace_idx, event in enumerate(parsed_data):
event.metadata["trace_idx"] = trace_idx

if (
hasattr(event, "tool_calls")
and event.tool_calls
and isinstance(event.tool_calls, list)
):
for tool_call in event.tool_calls:
tool_call.metadata["trace_idx"] = trace_idx

return parsed_data

def has_flow(self, a, b):
Expand Down
35 changes: 35 additions & 0 deletions invariant/tests/analyzer/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,41 @@ def test_analyze_pending(self):
self.assertTrue("Hello A!" in str(res.errors[0]))
self.assertTrue("Bye A!" in str(res.errors[1]))

def test_analyze_pending_detects_tool_calls(self):
    """Verify that analyze_pending reports a violation raised by a tool call.

    The policy raises ``PolicyViolation`` for any ``ToolCall`` matching
    ``tool:dummy_tool``. The user message is handed to ``analyze_pending``
    as its first argument and the assistant message carrying the tool call
    as its second; exactly one error containing the policy's message is
    expected.
    """
    error_message = "dummy_tool should not be called"

    # Rule: flag every ToolCall whose target is dummy_tool.
    policy = Monitor.from_string(f"""
    raise PolicyViolation("{error_message}") if:
        (call: ToolCall)
        call is tool:dummy_tool
    """)

    messages = [
        {
            "role": "user",
            "content": "whatever",
        },
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_5",
                    "type": "function",
                    "function": {"name": "dummy_tool", "arguments": {"name": "123"}},
                }
            ],
        },
    ]

    # Everything but the last message goes in first; the final assistant
    # message (the one with the tool call) is passed separately.
    res = policy.analyze_pending(messages[:-1], [messages[-1]])

    self.assertEqual(len(res.errors), 1)
    self.assertIsInstance(res.errors[0], ErrorInformation)
    # assertIn shows both operands on failure, unlike assertTrue(a in b).
    self.assertIn(error_message, str(res.errors[0]))


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Loading