Skip to content

Commit 1cb4504

Browse files
upstream guardrails updates (#57)
1 parent d34e58d commit 1cb4504

File tree

17 files changed

+2018
-65
lines changed

17 files changed

+2018
-65
lines changed

invariant/analyzer/language/scope.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"tuple",
3333
"tool_call",
3434
"text",
35+
"server",
3536
"image",
3637
"Tool",
3738
"ToolParameter",

invariant/analyzer/runtime/evaluation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ async def process_input(input_dict):
339339
print()
340340

341341
# track number of rule body evaluations
342-
evaluation_context.evaluation_counter += 1
342+
evaluation_context.increment_evaluation_counter()
343343

344344
result, new_variable_domains, ranges = await Interpreter.eval(
345345
expr_or_list,
@@ -418,6 +418,7 @@ def __init__(self, variable_store, globals, evaluation_context=None, partial=Tru
418418

419419
self.ranges = []
420420

421+
# output stream for printing
421422
self.output_stream = sys.stdout
422423

423424
# variable ranges describe the domain of all encountered

invariant/analyzer/runtime/evaluation_context.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,40 @@
44
In a separate file, for better separation of dependencies.
55
"""
66

7+
import os
78
from typing import Optional, TypeVar
89

10+
from invariant.analyzer.language.ast import Node
911
from invariant.analyzer.runtime.input import Input
1012
from invariant.analyzer.runtime.symbol_table import SymbolTable
11-
from invariant.analyzer.language.ast import Node
1213

1314
R = TypeVar("R")
1415

16+
1517
class EvaluationContext:
1618
"""
1719
An evaluation context enables a caller to handle the
1820
evaluation of external functions explicitly (e.g. for caching)
1921
and provide their own flow semantics (e.g. lookup in a graph).
2022
"""
2123

22-
def __init__(self, symbol_table: Optional[SymbolTable] = None):
24+
def __init__(
25+
self, maximum_iterations: int | None = None, symbol_table: Optional[SymbolTable] = None
26+
):
2327
self.symbol_table = symbol_table
2428

2529
self.evaluation_counter = 0
30+
self.maximum_iterations = int(os.environ.get("INVARIANT_MAX_ITERATIONS", 100))
31+
32+
def increment_evaluation_counter(self):
33+
self.evaluation_counter += 1
34+
if (
35+
self.maximum_iterations is not None
36+
and self.evaluation_counter > self.maximum_iterations
37+
):
38+
raise RuntimeError(
39+
f"Maximum checking cycles exceeded: {self.evaluation_counter - 1} (please contact Invariant support if you need to increase this limit, or rewrite your rule to be less expensive)"
40+
)
2641

2742
def call_function(self, function, args, **kwargs):
2843
raise NotImplementedError("EvaluationContext must implement call_function()")

invariant/analyzer/runtime/function_cache.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,36 @@
1+
import asyncio
12
import inspect
23
from typing import Awaitable, Callable, ParamSpec, TypeVar
34

45
P = ParamSpec("P")
56
R = TypeVar("R")
67

78

9+
class DictLock:
10+
def __init__(self):
11+
self.locks = {}
12+
self.main_lock = asyncio.Lock()
13+
14+
async def key(self, key):
15+
async with self.main_lock:
16+
if key not in self.locks:
17+
self.locks[key] = DictLockValue(key)
18+
return self.locks[key]
19+
20+
21+
class DictLockValue:
22+
def __init__(self, value):
23+
self.value = value
24+
self.lock = asyncio.Lock()
25+
26+
async def __aenter__(self):
27+
await self.lock.acquire()
28+
29+
async def __aexit__(self, exc_type, exc_value, traceback):
30+
self.lock.release()
31+
return False
32+
33+
834
class FunctionCache:
935
"""
1036
The function cache is responsible for handling function calls in policy rules
@@ -16,13 +42,14 @@ class FunctionCache:
1642

1743
def __init__(self, cache=None):
1844
self.cache = cache or {}
45+
self.cache_locks = DictLock()
1946

2047
def clear(self):
2148
self.cache = {}
2249

2350
def arg_key(self, arg):
2451
# cache primitives by value
25-
if type(arg) is int or type(arg) is float or type(arg) is str:
52+
if type(arg) is int or type(arg) is float or type(arg) is str or type(arg) is bool:
2653
return arg
2754
# cache lists by id
2855
elif type(arg) is list:
@@ -56,16 +83,17 @@ async def acall(
5683
# if function is not marked with @cached we just call it directly (see ./functions.py module)
5784
if not hasattr(function, "__invariant_cache__"):
5885
return await call_either_way(function, *args, **kwargs)
59-
6086
key = self.call_key(function, args, kwargs)
61-
value = await self.get(key)
62-
63-
if value is not None:
64-
return value
65-
else:
66-
value = await call_either_way(function, *args, **kwargs)
67-
await self.set(key, value)
68-
return value
87+
88+
async with await self.cache_locks.key(key):
89+
value = await self.get(key)
90+
91+
if value is not None:
92+
return value
93+
else:
94+
value = await call_either_way(function, *args, **kwargs)
95+
await self.set(key, value)
96+
return value
6997

7098

7199
async def call_either_way(

invariant/analyzer/runtime/input.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import warnings
1212
from collections.abc import ItemsView, KeysView, ValuesView
1313
from copy import deepcopy
14-
from typing import Callable, List, Optional
14+
from typing import Callable, List, Optional, get_args
1515

1616
from pydantic import BaseModel
1717
from rich.pretty import pprint as rich_print
@@ -27,6 +27,7 @@
2727
ToolCall,
2828
ToolOutput,
2929
ToolParameter,
30+
ToolParameterType,
3031
)
3132
from invariant.analyzer.runtime.runtime_errors import InvariantInputValidationError
3233

@@ -381,6 +382,12 @@ def parse_input(self, input: list[dict]) -> list[Event]:
381382
if not isinstance(event, dict):
382383
parsed_data.append(event)
383384
continue
385+
386+
# Extract relevant metadata from the event
387+
if "metadata" in event:
388+
if "server" in event["metadata"]:
389+
event["server"] = event["metadata"]["server"]
390+
384391
if "role" in event:
385392
if event["role"] != "tool":
386393
# if arguments are given as string convert them into dict using json.loads(...)
@@ -394,7 +401,6 @@ def parse_input(self, input: list[dict]) -> list[Event]:
394401
# # convert .content str to [{"type": "text": <content>}]
395402
# if type(event.get("content")) is str:
396403
# event["content"] = [{"type": "text", "text": event["content"]}]
397-
398404
msg = Message(**event)
399405
parsed_data.append(msg)
400406
if msg.tool_calls is not None:
@@ -420,22 +426,24 @@ def parse_tool_param(
420426
name: str, schema: dict, required_keys: Optional[List[str]] = None
421427
) -> ToolParameter:
422428
param_type = schema.get("type", "string")
423-
description = schema.get("description", "")
429+
description = schema.get("description")
430+
if description is None:
431+
description = "no description available"
424432

425433
# Only object-level schemas have required fields as a list
426434
if required_keys is None:
427435
required_keys = schema.get("required", [])
428436

429437
aliases = {
430-
"integer": "number",
431-
"int": "number",
438+
"int": "integer",
439+
"long": "integer",
432440
"float": "number",
433441
"bool": "boolean",
434442
"str": "string",
435443
"dict": "object",
436444
"list": "array",
437445
}
438-
if param_type in aliases:
446+
if isinstance(param_type, str) and param_type in aliases:
439447
param_type = aliases[param_type]
440448

441449
if param_type == "object":
@@ -446,30 +454,56 @@ def parse_tool_param(
446454
schema=subschema,
447455
required_keys=schema.get("required", []),
448456
)
457+
additional_properties = (
458+
bool(schema.get("additionalProperties"))
459+
if schema.get("additionalProperties") is not None
460+
else None
461+
)
462+
449463
return ToolParameter(
450464
name=name,
451465
type="object",
452466
description=description,
453467
required=name in required_keys,
454468
properties=properties,
455-
additionalProperties=schema.get("additionalProperties"),
469+
additionalProperties=additional_properties,
456470
)
457471
elif param_type == "array":
458472
return ToolParameter(
459473
name=name,
460474
type="array",
461475
description=description,
462476
required=name in required_keys,
463-
items=parse_tool_param(name=f"{name} item", schema=schema["items"]),
477+
items=parse_tool_param(
478+
name=f"{name} item", schema=schema["items"], required_keys=[]
479+
)
480+
if "items" in schema
481+
else None,
464482
)
465-
elif param_type in ["object", "array", "string", "number", "boolean"]:
483+
elif param_type in ["string", "number", "integer", "boolean"]:
466484
return ToolParameter(
467485
name=name,
468486
type=param_type,
469487
description=description,
470488
required=name in required_keys,
471489
enum=schema.get("enum"),
472490
)
491+
elif isinstance(param_type, list):
492+
required = name in required_keys
493+
for param in param_type:
494+
if "null" in param:
495+
required = False
496+
continue
497+
if param not in get_args(ToolParameterType):
498+
raise InvariantInputValidationError(
499+
f"Unsupported schema type: {param} for parameter {name}. Supported types are: object, array, string, number, boolean."
500+
)
501+
return ToolParameter(
502+
name=name,
503+
type=[p for p in param_type if p != "null"],
504+
description=description,
505+
required=required,
506+
)
473507
else:
474508
raise InvariantInputValidationError(
475509
f"Unsupported schema type: {param_type} for parameter {name}. Supported types are: object, array, string, number, boolean."
@@ -488,10 +522,22 @@ def parse_tool_param(
488522
)
489523
)
490524

525+
tool_desc = tool.get("description")
526+
if tool_desc is None:
527+
tool_desc = "no description available"
528+
if not isinstance(tool_desc, str):
529+
try:
530+
tool_desc = str(tool_desc)
531+
except Exception:
532+
raise InvariantInputValidationError(
533+
f"Tool description should be a string. Instead, got: {tool_desc} of type {type(tool_desc)}"
534+
)
535+
server = tool["metadata"].get("server", None) if "metadata" in tool else None
491536
tool_obj = Tool(
492537
name=name,
493-
description=tool["description"],
538+
description=tool_desc,
494539
inputSchema=properties,
540+
server=server,
495541
)
496542
parsed_data.append(tool_obj)
497543
else:

invariant/analyzer/runtime/nodes.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class Event(BaseModel):
1717
default_factory=dict, description="Metadata associated with the event"
1818
)
1919

20+
server: Optional[str] = None
21+
2022

2123
class Function(BaseModel):
2224
name: str
@@ -309,8 +311,18 @@ def image(*args) -> list[str]:
309311
return result
310312

311313

314+
ToolParameterType = Literal[
315+
"object",
316+
"array",
317+
"string",
318+
"number",
319+
"integer",
320+
"boolean",
321+
] # extend as needed
322+
323+
312324
class ToolParameter(BaseModel):
313-
type: Literal["object", "array", "string", "number", "boolean"] # extend as needed
325+
type: ToolParameterType | list[ToolParameterType]
314326
name: str
315327
description: str
316328
required: bool = False
@@ -351,7 +363,7 @@ class Tool(Event):
351363
inputSchema: list[ToolParameter]
352364

353365
def __invariant_attribute__(self, name: str):
354-
if name in ["name", "description", "inputSchema"]:
366+
if name in ["name", "description", "inputSchema", "server"]:
355367
return getattr(self, name)
356368
raise InvariantAttributeError(
357369
f"Attribute {name} not found in Tool. Available attributes are: name, description, inputSchema"

invariant/analyzer/runtime/rule.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import textwrap
3-
from typing import Optional, Callable, TypeVar, ParamSpec, Awaitable
3+
from typing import Awaitable, Callable, Optional, ParamSpec, TypeVar
44

55
import invariant.analyzer.language.ast as ast
66
from invariant.analyzer.language.linking import link
@@ -18,6 +18,7 @@
1818
P = ParamSpec("P")
1919
R = TypeVar("R")
2020

21+
2122
class PolicyAction:
2223
def __call__(self, input_dict):
2324
raise NotImplementedError()
@@ -154,14 +155,20 @@ def from_raise_policy(cls, policy: ast.RaisePolicy, globals):
154155

155156
class InputEvaluationContext(EvaluationContext):
156157
def __init__(
157-
self, input: Input, rule_set: "RuleSet", policy_parameters, symbol_table: Optional[SymbolTable]
158+
self,
159+
input: Input,
160+
rule_set: "RuleSet",
161+
policy_parameters,
162+
symbol_table: Optional[SymbolTable],
158163
):
159164
super().__init__(symbol_table=symbol_table)
160165
self.input = input
161166
self.rule_set = rule_set
162167
self.policy_parameters = policy_parameters
163168

164-
async def acall_function(self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
169+
async def acall_function(
170+
self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs
171+
) -> R:
165172
return await self.rule_set.acall_function(func, *args, **kwargs)
166173

167174
def has_flow(self, a, b):
@@ -211,7 +218,9 @@ def instance(self, cache: Optional[FunctionCache] = None):
211218
def __del__(self):
212219
self.function_cache.clear()
213220

214-
async def acall_function(self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
221+
async def acall_function(
222+
self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs
223+
) -> R:
215224
return await self.function_cache.acall(func, *args, **kwargs)
216225

217226
def log_apply(self, rule, model):
@@ -248,6 +257,7 @@ async def apply(self, input_data: Input, policy_parameters):
248257
exceptions.extend(error)
249258

250259
self.input = None
260+
251261
return exceptions
252262

253263
def __str__(self):

0 commit comments

Comments
 (0)