Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions invariant/analyzer/examples/agent_bugs/traceset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def analyze(self, policy):
def filter(
self,
invariant_condition: str,
max_items: int = None,
python: str = None,
prefix: str = None,
max_items: int | None = None,
python: str | None = None,
prefix: str | None = None,
) -> "TraceSet":
max_items = self.get_max_items(max_items)

Expand Down Expand Up @@ -155,7 +155,7 @@ def get_max_items(self, max_items):

return max_items

def prepare_policy(self, invariant_condition: str, prefix: str = None):
def prepare_policy(self, invariant_condition: str, prefix: str | None = None):
from invariant.analyzer import Policy

# construct makeshift policy
Expand Down
2 changes: 1 addition & 1 deletion invariant/analyzer/examples/agent_flan/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def has_bash(conv):
else:
tool_called = True
match = re.search(r"```bash\n(.+?)\n```", msg["content"], re.DOTALL)
bash_cmd = match.group(1)
bash_cmd = match.group(1) if match else None
# print("-> ", bash_cmd)
ipl_messages += [
{
Expand Down
2 changes: 2 additions & 0 deletions invariant/analyzer/language/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"tool_call",
"text",
"image",
"Tool",
"ToolParameter",
]


Expand Down
11 changes: 1 addition & 10 deletions invariant/analyzer/language/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,6 @@ def visit_MemberAccess(self, node: MemberAccess):
node.type = UnknownType()
return node

def has_member(obj_type, member):
if isinstance(obj_type, type) and hasattr(obj_type, "__dataclass_fields__"):
# check on type reference if attribute exists
return member in obj_type.__dataclass_fields__
return False

if not has_member(node.expr.type, node.member):
raise PolicyError(f"Type {node.expr}: {node.expr.type} has no member {node.member}")

node.type = UnknownType()

return node
Expand Down Expand Up @@ -245,7 +236,7 @@ def visit_ListComprehension(self, node: ListComprehension):
node.scope.parent = self.context.scope

# Add the iteration variable to the scope
var_name = node.var_name.name if hasattr(node.var_name, 'name') else node.var_name
var_name = node.var_name.name if hasattr(node.var_name, "name") else node.var_name
var_decl = VariableDeclaration(var_name, UnknownType())
node.scope.declarations = {var_name: var_decl}

Expand Down
6 changes: 3 additions & 3 deletions invariant/analyzer/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def errors(self):
return self.policy_root.errors

@classmethod
def from_file(cls, path: str) -> "Policy":
def from_file(cls, path: str) -> "LocalPolicy":
return cls(parse_file(path))

@classmethod
Expand All @@ -80,14 +80,14 @@ def from_string(
path: str | None = None,
optimize: bool = False,
symbol_table: Optional[SymbolTable] = None,
) -> "Policy":
) -> "LocalPolicy":
return cls(parse(string, path, optimize_rules=optimize), symbol_table=symbol_table)

def add_error_to_result(self, error, analysis_result: AnalysisResult):
"""Implements how errors are added to an analysis result (e.g. as handled or non-handled errors)."""
analysis_result.errors.append(error)

def analyze(self, input: dict, raise_unhandled=False, **policy_parameters):
def analyze(self, input: list[dict], raise_unhandled=False, **policy_parameters):
return asyncio.run(self.a_analyze(input, raise_unhandled, **policy_parameters))

async def a_analyze(
Expand Down
13 changes: 12 additions & 1 deletion invariant/analyzer/remote_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from invariant.analyzer.runtime.runtime_errors import (
ExcessivePolicyError,
InvariantAttributeError,
InvariantInputValidationError,
MissingPolicyParameter,
PolicyExecutionError,
)
Expand Down Expand Up @@ -111,7 +112,7 @@ async def a_preload(self):
if "errors" in result:
raise ValueError("Invalid response from policy service: " + str(result))

def analyze(self, input: dict, raise_unhandled=False, **policy_parameters):
def analyze(self, input: list[dict], raise_unhandled=False, **policy_parameters):
return asyncio.run(self.a_analyze(input, raise_unhandled, **policy_parameters))

def get_json_policy_parameters(self, **policy_parameters):
Expand Down Expand Up @@ -261,6 +262,16 @@ async def handle_error_response(response: aiohttp.ClientResponse):
except Exception as e:
text = str(e)
raise InvariantAttributeError(text)
elif InvariantInputValidationError.catchphrase in text:
# get 'detail' from json
try:
error_json = await response.json()
text = error_json.get("detail", str(error_json))
# remove 'Invariant attribute error' from text
text = text.replace(InvariantInputValidationError.catchphrase, "")
except Exception as e:
text = str(e)
raise InvariantInputValidationError(text)
else:
raise ValueError("Invalid response from policy service: " + text)

Expand Down
40 changes: 28 additions & 12 deletions invariant/analyzer/runtime/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import sys
from dataclasses import dataclass # noqa: I001
from itertools import product
from typing import AsyncGenerator
from typing import AsyncGenerator, Awaitable, Callable, ParamSpec, TypeVar

import termcolor

P = ParamSpec("P")
R = TypeVar("R")
from invariant.analyzer.language.ast import (
ArrayLiteral,
BinaryExpr,
Expand Down Expand Up @@ -786,10 +788,6 @@ def _is_unknown(self, value):
async def visit_FunctionCall(self, node: FunctionCall):
function = await self.visit(node.name)
args = await asyncio.gather(*[self.visit(arg) for arg in node.args])

# only call functions, once all parameters are known
if function is Unknown or any(self._is_unknown(arg) for arg in args):
return Unknown
kwarg_items = await asyncio.gather(
*[
asyncio.gather(self.visit(entry.key), self.visit(entry.value))
Expand All @@ -798,6 +796,14 @@ async def visit_FunctionCall(self, node: FunctionCall):
)
kwargs = {k: v for k, v in kwarg_items}

# only call functions, once all parameters are known
if (
function is Unknown
or any(self._is_unknown(arg) for arg in args)
or any(self._is_unknown(v) for v in kwargs.values())
):
return Unknown

if isinstance(function, Declaration):
return await self.visit_PredicateCall(function, args, **kwargs)
else:
Expand Down Expand Up @@ -932,6 +938,9 @@ async def visit_ListComprehension(self, node: ListComprehension):
if iterable is None:
return []

if iterable is Unknown:
return Unknown

var_name = node.var_name.id if hasattr(node.var_name, "id") else node.var_name
results = []
original_vars = self.variable_store.copy()
Expand All @@ -953,14 +962,21 @@ async def visit_ListComprehension(self, node: ListComprehension):
self.variable_store = original_vars
return results

async def acall_function(self, function, *args, **kwargs):
async def acall_function(
self,
function: Callable[P, Awaitable[R]] | Callable[P, R] | CachedFunctionWrapper,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
ctx: EvaluationContext = self.evaluation_context
# if function is a cached function wrapper, unwrap it (at this point we
# are already caching it)
if isinstance(function, CachedFunctionWrapper):
function = function.func
linked_function = ctx.link(function, None)
func = function.func if isinstance(function, CachedFunctionWrapper) else function
linked_function = ctx.link(func, None)
# also unwrap linked function, if a cached function wrapper
if isinstance(linked_function, CachedFunctionWrapper):
linked_function = linked_function.func
return await ctx.acall_function(linked_function, args, **kwargs)
linked_function = (
linked_function.func
if isinstance(linked_function, CachedFunctionWrapper)
else linked_function
)
return await ctx.acall_function(linked_function, *args, **kwargs)
6 changes: 4 additions & 2 deletions invariant/analyzer/runtime/evaluation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
In a separate file, for better separation of dependencies.
"""

from typing import Optional
from typing import Optional, TypeVar

from invariant.analyzer.runtime.input import Input
from invariant.analyzer.runtime.symbol_table import SymbolTable
from invariant.analyzer.language.ast import Node

R = TypeVar("R")

class EvaluationContext:
"""
Expand All @@ -28,7 +30,7 @@ def call_function(self, function, args, **kwargs):
async def acall_function(self, function, args, **kwargs):
raise NotImplementedError("EvaluationContext must implement acall_function()")

def link(self, function, node):
def link(self, function: R, node: Node | None) -> R:
if self.symbol_table:
return self.symbol_table.link(function, node)
else:
Expand Down
22 changes: 17 additions & 5 deletions invariant/analyzer/runtime/function_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import inspect
from typing import Awaitable, Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


class FunctionCache:
Expand Down Expand Up @@ -43,7 +47,12 @@ async def get(self, key):
async def set(self, key, value):
self.cache[key] = value

async def acall(self, function, args, **kwargs):
async def acall(
self,
function: Callable[P, Awaitable[R]] | Callable[P, R],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
# if function is not marked with @cached we just call it directly (see ./functions.py module)
if not hasattr(function, "__invariant_cache__"):
return await call_either_way(function, *args, **kwargs)
Expand All @@ -59,8 +68,11 @@ async def acall(self, function, args, **kwargs):
return value


async def call_either_way(fct, *args, **kwargs):
if inspect.iscoroutinefunction(fct):
return await fct(*args, **kwargs)
async def call_either_way(
func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return fct(*args, **kwargs)
print([func, args, kwargs], flush=True)
return func(*args, **kwargs) # type: ignore
14 changes: 8 additions & 6 deletions invariant/analyzer/runtime/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
with special runtime attributes, relevant in the context of the
invariant agent analyzer.
"""
from typing import Callable, TypeVar, ParamSpec, Awaitable, Generic

P = ParamSpec("P")
R = TypeVar("R")

def cached(func):
def cached(func: Callable[P, Awaitable[R]] | Callable[P, R]) -> Callable[P, Awaitable[R]]:
"""
Decorator to mark a guardrailing function or built-in predicate as interpreter-cached.

Expand Down Expand Up @@ -38,19 +41,18 @@ def guardrailing_function(...):

return CachedFunctionWrapper(func)


class CachedFunctionWrapper:
class CachedFunctionWrapper(Generic[P, R]):
"""
Wraps a function such that is is always called via the current Interpreter instance.

This enables caching and other runtime features like function re-linking in a server context.
"""

def __init__(self, func):
def __init__(self, func: Callable[P, Awaitable[R]] | Callable[P, R]):
self.func = func
self.func.__invariant_cache__ = True
self.func.__invariant_cache__ = True # type: ignore

def __call__(self, *args, **kwargs):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
from invariant.analyzer.runtime.evaluation import Interpreter

return Interpreter.current().acall_function(self.func, *args, **kwargs)
Loading