Skip to content

Commit 3609750

Browse files
authored
upstream guardrails updates (#50)
1 parent 123104c commit 3609750

File tree

25 files changed

+478
-79
lines changed

25 files changed

+478
-79
lines changed

invariant/analyzer/examples/agent_bugs/traceset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ def analyze(self, policy):
115115
def filter(
116116
self,
117117
invariant_condition: str,
118-
max_items: int = None,
119-
python: str = None,
120-
prefix: str = None,
118+
max_items: int | None = None,
119+
python: str | None = None,
120+
prefix: str | None = None,
121121
) -> "TraceSet":
122122
max_items = self.get_max_items(max_items)
123123

@@ -155,7 +155,7 @@ def get_max_items(self, max_items):
155155

156156
return max_items
157157

158-
def prepare_policy(self, invariant_condition: str, prefix: str = None):
158+
def prepare_policy(self, invariant_condition: str, prefix: str | None = None):
159159
from invariant.analyzer import Policy
160160

161161
# construct makeshift policy

invariant/analyzer/examples/agent_flan/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def has_bash(conv):
5252
else:
5353
tool_called = True
5454
match = re.search(r"```bash\n(.+?)\n```", msg["content"], re.DOTALL)
55-
bash_cmd = match.group(1)
55+
bash_cmd = match.group(1) if match else None
5656
# print("-> ", bash_cmd)
5757
ipl_messages += [
5858
{

invariant/analyzer/language/scope.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
"tool_call",
3434
"text",
3535
"image",
36+
"Tool",
37+
"ToolParameter",
3638
]
3739

3840

invariant/analyzer/language/typing.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,6 @@ def visit_MemberAccess(self, node: MemberAccess):
158158
node.type = UnknownType()
159159
return node
160160

161-
def has_member(obj_type, member):
162-
if isinstance(obj_type, type) and hasattr(obj_type, "__dataclass_fields__"):
163-
# check on type reference if attribute exists
164-
return member in obj_type.__dataclass_fields__
165-
return False
166-
167-
if not has_member(node.expr.type, node.member):
168-
raise PolicyError(f"Type {node.expr}: {node.expr.type} has no member {node.member}")
169-
170161
node.type = UnknownType()
171162

172163
return node
@@ -245,7 +236,7 @@ def visit_ListComprehension(self, node: ListComprehension):
245236
node.scope.parent = self.context.scope
246237

247238
# Add the iteration variable to the scope
248-
var_name = node.var_name.name if hasattr(node.var_name, 'name') else node.var_name
239+
var_name = node.var_name.name if hasattr(node.var_name, "name") else node.var_name
249240
var_decl = VariableDeclaration(var_name, UnknownType())
250241
node.scope.declarations = {var_name: var_decl}
251242

invariant/analyzer/policy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def errors(self):
7070
return self.policy_root.errors
7171

7272
@classmethod
73-
def from_file(cls, path: str) -> "Policy":
73+
def from_file(cls, path: str) -> "LocalPolicy":
7474
return cls(parse_file(path))
7575

7676
@classmethod
@@ -80,14 +80,14 @@ def from_string(
8080
path: str | None = None,
8181
optimize: bool = False,
8282
symbol_table: Optional[SymbolTable] = None,
83-
) -> "Policy":
83+
) -> "LocalPolicy":
8484
return cls(parse(string, path, optimize_rules=optimize), symbol_table=symbol_table)
8585

8686
def add_error_to_result(self, error, analysis_result: AnalysisResult):
8787
"""Implements how errors are added to an analysis result (e.g. as handled or non-handled errors)."""
8888
analysis_result.errors.append(error)
8989

90-
def analyze(self, input: dict, raise_unhandled=False, **policy_parameters):
90+
def analyze(self, input: list[dict], raise_unhandled=False, **policy_parameters):
9191
return asyncio.run(self.a_analyze(input, raise_unhandled, **policy_parameters))
9292

9393
async def a_analyze(

invariant/analyzer/remote_policy.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from invariant.analyzer.runtime.runtime_errors import (
1010
ExcessivePolicyError,
1111
InvariantAttributeError,
12+
InvariantInputValidationError,
1213
MissingPolicyParameter,
1314
PolicyExecutionError,
1415
)
@@ -111,7 +112,7 @@ async def a_preload(self):
111112
if "errors" in result:
112113
raise ValueError("Invalid response from policy service: " + str(result))
113114

114-
def analyze(self, input: dict, raise_unhandled=False, **policy_parameters):
115+
def analyze(self, input: list[dict], raise_unhandled=False, **policy_parameters):
115116
return asyncio.run(self.a_analyze(input, raise_unhandled, **policy_parameters))
116117

117118
def get_json_policy_parameters(self, **policy_parameters):
@@ -261,6 +262,16 @@ async def handle_error_response(response: aiohttp.ClientResponse):
261262
except Exception as e:
262263
text = str(e)
263264
raise InvariantAttributeError(text)
265+
elif InvariantInputValidationError.catchphrase in text:
266+
# get 'detail' from json
267+
try:
268+
error_json = await response.json()
269+
text = error_json.get("detail", str(error_json))
270+
# remove 'Invariant attribute error' from text
271+
text = text.replace(InvariantInputValidationError.catchphrase, "")
272+
except Exception as e:
273+
text = str(e)
274+
raise InvariantInputValidationError(text)
264275
else:
265276
raise ValueError("Invalid response from policy service: " + text)
266277

invariant/analyzer/runtime/evaluation.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import sys
66
from dataclasses import dataclass # noqa: I001
77
from itertools import product
8-
from typing import AsyncGenerator
8+
from typing import AsyncGenerator, Awaitable, Callable, ParamSpec, TypeVar
99

1010
import termcolor
1111

12+
P = ParamSpec("P")
13+
R = TypeVar("R")
1214
from invariant.analyzer.language.ast import (
1315
ArrayLiteral,
1416
BinaryExpr,
@@ -786,10 +788,6 @@ def _is_unknown(self, value):
786788
async def visit_FunctionCall(self, node: FunctionCall):
787789
function = await self.visit(node.name)
788790
args = await asyncio.gather(*[self.visit(arg) for arg in node.args])
789-
790-
# only call functions, once all parameters are known
791-
if function is Unknown or any(self._is_unknown(arg) for arg in args):
792-
return Unknown
793791
kwarg_items = await asyncio.gather(
794792
*[
795793
asyncio.gather(self.visit(entry.key), self.visit(entry.value))
@@ -798,6 +796,14 @@ async def visit_FunctionCall(self, node: FunctionCall):
798796
)
799797
kwargs = {k: v for k, v in kwarg_items}
800798

799+
# only call functions, once all parameters are known
800+
if (
801+
function is Unknown
802+
or any(self._is_unknown(arg) for arg in args)
803+
or any(self._is_unknown(v) for v in kwargs.values())
804+
):
805+
return Unknown
806+
801807
if isinstance(function, Declaration):
802808
return await self.visit_PredicateCall(function, args, **kwargs)
803809
else:
@@ -932,6 +938,9 @@ async def visit_ListComprehension(self, node: ListComprehension):
932938
if iterable is None:
933939
return []
934940

941+
if iterable is Unknown:
942+
return Unknown
943+
935944
var_name = node.var_name.id if hasattr(node.var_name, "id") else node.var_name
936945
results = []
937946
original_vars = self.variable_store.copy()
@@ -953,14 +962,21 @@ async def visit_ListComprehension(self, node: ListComprehension):
953962
self.variable_store = original_vars
954963
return results
955964

956-
async def acall_function(self, function, *args, **kwargs):
965+
async def acall_function(
966+
self,
967+
function: Callable[P, Awaitable[R]] | Callable[P, R] | CachedFunctionWrapper,
968+
*args: P.args,
969+
**kwargs: P.kwargs,
970+
) -> R:
957971
ctx: EvaluationContext = self.evaluation_context
958972
# if function is a cached function wrapper, unwrap it (at this point we
959973
# are already caching it)
960-
if isinstance(function, CachedFunctionWrapper):
961-
function = function.func
962-
linked_function = ctx.link(function, None)
974+
func = function.func if isinstance(function, CachedFunctionWrapper) else function
975+
linked_function = ctx.link(func, None)
963976
# also unwrap linked function, if a cached function wrapper
964-
if isinstance(linked_function, CachedFunctionWrapper):
965-
linked_function = linked_function.func
966-
return await ctx.acall_function(linked_function, args, **kwargs)
977+
linked_function = (
978+
linked_function.func
979+
if isinstance(linked_function, CachedFunctionWrapper)
980+
else linked_function
981+
)
982+
return await ctx.acall_function(linked_function, *args, **kwargs)

invariant/analyzer/runtime/evaluation_context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
In a separate file, for better separation of dependencies.
55
"""
66

7-
from typing import Optional
7+
from typing import Optional, TypeVar
88

99
from invariant.analyzer.runtime.input import Input
1010
from invariant.analyzer.runtime.symbol_table import SymbolTable
11+
from invariant.analyzer.language.ast import Node
1112

13+
R = TypeVar("R")
1214

1315
class EvaluationContext:
1416
"""
@@ -28,7 +30,7 @@ def call_function(self, function, args, **kwargs):
2830
async def acall_function(self, function, args, **kwargs):
2931
raise NotImplementedError("EvaluationContext must implement acall_function()")
3032

31-
def link(self, function, node):
33+
def link(self, function: R, node: Node | None) -> R:
3234
if self.symbol_table:
3335
return self.symbol_table.link(function, node)
3436
else:

invariant/analyzer/runtime/function_cache.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import inspect
2+
from typing import Awaitable, Callable, ParamSpec, TypeVar
3+
4+
P = ParamSpec("P")
5+
R = TypeVar("R")
26

37

48
class FunctionCache:
@@ -43,7 +47,12 @@ async def get(self, key):
4347
async def set(self, key, value):
4448
self.cache[key] = value
4549

46-
async def acall(self, function, args, **kwargs):
50+
async def acall(
51+
self,
52+
function: Callable[P, Awaitable[R]] | Callable[P, R],
53+
*args: P.args,
54+
**kwargs: P.kwargs,
55+
) -> R:
4756
# if function is not marked with @cached we just call it directly (see ./functions.py module)
4857
if not hasattr(function, "__invariant_cache__"):
4958
return await call_either_way(function, *args, **kwargs)
@@ -59,8 +68,11 @@ async def acall(self, function, args, **kwargs):
5968
return value
6069

6170

62-
async def call_either_way(fct, *args, **kwargs):
63-
if inspect.iscoroutinefunction(fct):
64-
return await fct(*args, **kwargs)
71+
async def call_either_way(
72+
func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs
73+
) -> R:
74+
if inspect.iscoroutinefunction(func):
75+
return await func(*args, **kwargs)
6576
else:
66-
return fct(*args, **kwargs)
77+
print([func, args, kwargs], flush=True)
78+
return func(*args, **kwargs) # type: ignore

invariant/analyzer/runtime/functions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
with special runtime attributes, relevant in the context of the
44
invariant agent analyzer.
55
"""
6+
from typing import Callable, TypeVar, ParamSpec, Awaitable, Generic
67

8+
P = ParamSpec("P")
9+
R = TypeVar("R")
710

8-
def cached(func):
11+
def cached(func: Callable[P, Awaitable[R]] | Callable[P, R]) -> Callable[P, Awaitable[R]]:
912
"""
1013
Decorator to mark a guardrailing function or built-in predicate as interpreter-cached.
1114
@@ -38,19 +41,18 @@ def guardrailing_function(...):
3841

3942
return CachedFunctionWrapper(func)
4043

41-
42-
class CachedFunctionWrapper:
44+
class CachedFunctionWrapper(Generic[P, R]):
4345
"""
4446
Wraps a function such that is is always called via the current Interpreter instance.
4547
4648
This enables caching and other runtime features like function re-linking in a server context.
4749
"""
4850

49-
def __init__(self, func):
51+
def __init__(self, func: Callable[P, Awaitable[R]] | Callable[P, R]):
5052
self.func = func
51-
self.func.__invariant_cache__ = True
53+
self.func.__invariant_cache__ = True # type: ignore
5254

53-
def __call__(self, *args, **kwargs):
55+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
5456
from invariant.analyzer.runtime.evaluation import Interpreter
5557

5658
return Interpreter.current().acall_function(self.func, *args, **kwargs)

0 commit comments

Comments
 (0)