Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions invariant/analyzer/language/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"tuple",
"tool_call",
"text",
"server",
"image",
"Tool",
"ToolParameter",
Expand Down
3 changes: 2 additions & 1 deletion invariant/analyzer/runtime/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ async def process_input(input_dict):
print()

# track number of rule body evaluations
evaluation_context.evaluation_counter += 1
evaluation_context.increment_evaluation_counter()

result, new_variable_domains, ranges = await Interpreter.eval(
expr_or_list,
Expand Down Expand Up @@ -418,6 +418,7 @@ def __init__(self, variable_store, globals, evaluation_context=None, partial=Tru

self.ranges = []

# output stream for printing
self.output_stream = sys.stdout

# variable ranges describe the domain of all encountered
Expand Down
19 changes: 17 additions & 2 deletions invariant/analyzer/runtime/evaluation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,40 @@
In a separate file, for better separation of dependencies.
"""

import os
from typing import Optional, TypeVar

from invariant.analyzer.language.ast import Node
from invariant.analyzer.runtime.input import Input
from invariant.analyzer.runtime.symbol_table import SymbolTable
from invariant.analyzer.language.ast import Node

R = TypeVar("R")


class EvaluationContext:
"""
An evaluation context enables a caller to handle the
evaluation of external functions explicitly (e.g. for caching)
and provide their own flow semantics (e.g. lookup in a graph).
"""

def __init__(self, symbol_table: Optional[SymbolTable] = None):
def __init__(
self, maximum_iterations: int | None = None, symbol_table: Optional[SymbolTable] = None
):
self.symbol_table = symbol_table

self.evaluation_counter = 0
self.maximum_iterations = int(os.environ.get("INVARIANT_MAX_ITERATIONS", 100))

def increment_evaluation_counter(self):
self.evaluation_counter += 1
if (
self.maximum_iterations is not None
and self.evaluation_counter > self.maximum_iterations
):
raise RuntimeError(
f"Maximum checking cycles exceeded: {self.evaluation_counter - 1} (please contact Invariant support if you need to increase this limit, or rewrite your rule to be less expensive)"
)

def call_function(self, function, args, **kwargs):
    """Placeholder for evaluating an external function call.

    This base implementation never evaluates anything; it always raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "EvaluationContext must implement call_function()"
    raise NotImplementedError(message)
Expand Down
48 changes: 38 additions & 10 deletions invariant/analyzer/runtime/function_cache.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
import asyncio
import inspect
from typing import Awaitable, Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


class DictLock:
    """Registry that hands out one ``DictLockValue`` per key.

    Entries are created lazily on first request and kept in ``self.locks``;
    nothing in this class removes them again.
    """

    def __init__(self):
        # key -> DictLockValue, filled on demand by ``key``
        self.locks = {}
        # held while ``self.locks`` is inspected or extended
        self.main_lock = asyncio.Lock()

    async def key(self, key):
        """Return the ``DictLockValue`` registered for ``key``.

        A new entry is created and stored when ``key`` has not been seen
        before; later calls with the same key return that same object.
        """
        async with self.main_lock:
            entry = self.locks.get(key)
            if entry is None:
                entry = DictLockValue(key)
                self.locks[key] = entry
            return entry


class DictLockValue:
    """Async context manager bundling a value with its own ``asyncio.Lock``.

    Entering the context acquires the lock; leaving it releases the lock
    and lets any exception raised inside the ``async with`` body propagate.
    """

    def __init__(self, value):
        self.lock = asyncio.Lock()
        self.value = value

    async def __aenter__(self):
        # Nothing is bound by ``async with ... as name``: the result is None.
        held = self.lock
        await held.acquire()
        return None

    async def __aexit__(self, exc_type, exc_value, traceback):
        held = self.lock
        held.release()
        # a falsy result tells Python not to suppress the exception
        suppress = False
        return suppress


class FunctionCache:
"""
The function cache is responsible for handling function calls in policy rules
Expand All @@ -16,13 +42,14 @@ class FunctionCache:

def __init__(self, cache=None):
    """Create a function cache, optionally backed by an existing mapping.

    Args:
        cache: Mapping used to store cached results. When ``None`` (the
            default) a fresh, empty dict is created.
    """
    # Test against None instead of truthiness: ``cache or {}`` would replace
    # a caller-supplied *empty* mapping with a new dict, so results would
    # never land in the object the caller passed in.
    self.cache = {} if cache is None else cache
    # per-key locks used when evaluating cached calls
    self.cache_locks = DictLock()

def clear(self):
self.cache = {}

def arg_key(self, arg):
# cache primitives by value
if type(arg) is int or type(arg) is float or type(arg) is str:
if type(arg) is int or type(arg) is float or type(arg) is str or type(arg) is bool:
return arg
# cache lists by id
elif type(arg) is list:
Expand Down Expand Up @@ -56,16 +83,17 @@ async def acall(
# if function is not marked with @cached we just call it directly (see ./functions.py module)
if not hasattr(function, "__invariant_cache__"):
return await call_either_way(function, *args, **kwargs)

key = self.call_key(function, args, kwargs)
value = await self.get(key)

if value is not None:
return value
else:
value = await call_either_way(function, *args, **kwargs)
await self.set(key, value)
return value

async with await self.cache_locks.key(key):
value = await self.get(key)

if value is not None:
return value
else:
value = await call_either_way(function, *args, **kwargs)
await self.set(key, value)
return value


async def call_either_way(
Expand Down
66 changes: 56 additions & 10 deletions invariant/analyzer/runtime/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import warnings
from collections.abc import ItemsView, KeysView, ValuesView
from copy import deepcopy
from typing import Callable, List, Optional
from typing import Callable, List, Optional, get_args

from pydantic import BaseModel
from rich.pretty import pprint as rich_print
Expand All @@ -27,6 +27,7 @@
ToolCall,
ToolOutput,
ToolParameter,
ToolParameterType,
)
from invariant.analyzer.runtime.runtime_errors import InvariantInputValidationError

Expand Down Expand Up @@ -381,6 +382,12 @@ def parse_input(self, input: list[dict]) -> list[Event]:
if not isinstance(event, dict):
parsed_data.append(event)
continue

# Extract relevant metadata from the event
if "metadata" in event:
if "server" in event["metadata"]:
event["server"] = event["metadata"]["server"]

if "role" in event:
if event["role"] != "tool":
# if arguments are given as string convert them into dict using json.loads(...)
Expand All @@ -394,7 +401,6 @@ def parse_input(self, input: list[dict]) -> list[Event]:
# # convert .content str to [{"type": "text": <content>}]
# if type(event.get("content")) is str:
# event["content"] = [{"type": "text", "text": event["content"]}]

msg = Message(**event)
parsed_data.append(msg)
if msg.tool_calls is not None:
Expand All @@ -420,22 +426,24 @@ def parse_tool_param(
name: str, schema: dict, required_keys: Optional[List[str]] = None
) -> ToolParameter:
param_type = schema.get("type", "string")
description = schema.get("description", "")
description = schema.get("description")
if description is None:
description = "no description available"

# Only object-level schemas have required fields as a list
if required_keys is None:
required_keys = schema.get("required", [])

aliases = {
"integer": "number",
"int": "number",
"int": "integer",
"long": "integer",
"float": "number",
"bool": "boolean",
"str": "string",
"dict": "object",
"list": "array",
}
if param_type in aliases:
if isinstance(param_type, str) and param_type in aliases:
param_type = aliases[param_type]

if param_type == "object":
Expand All @@ -446,30 +454,56 @@ def parse_tool_param(
schema=subschema,
required_keys=schema.get("required", []),
)
additional_properties = (
bool(schema.get("additionalProperties"))
if schema.get("additionalProperties") is not None
else None
)

return ToolParameter(
name=name,
type="object",
description=description,
required=name in required_keys,
properties=properties,
additionalProperties=schema.get("additionalProperties"),
additionalProperties=additional_properties,
)
elif param_type == "array":
return ToolParameter(
name=name,
type="array",
description=description,
required=name in required_keys,
items=parse_tool_param(name=f"{name} item", schema=schema["items"]),
items=parse_tool_param(
name=f"{name} item", schema=schema["items"], required_keys=[]
)
if "items" in schema
else None,
)
elif param_type in ["object", "array", "string", "number", "boolean"]:
elif param_type in ["string", "number", "integer", "boolean"]:
return ToolParameter(
name=name,
type=param_type,
description=description,
required=name in required_keys,
enum=schema.get("enum"),
)
elif isinstance(param_type, list):
required = name in required_keys
for param in param_type:
if "null" in param:
required = False
continue
if param not in get_args(ToolParameterType):
raise InvariantInputValidationError(
f"Unsupported schema type: {param} for parameter {name}. Supported types are: object, array, string, number, boolean."
)
return ToolParameter(
name=name,
type=[p for p in param_type if p != "null"],
description=description,
required=required,
)
else:
raise InvariantInputValidationError(
f"Unsupported schema type: {param_type} for parameter {name}. Supported types are: object, array, string, number, boolean."
Expand All @@ -488,10 +522,22 @@ def parse_tool_param(
)
)

tool_desc = tool.get("description")
if tool_desc is None:
tool_desc = "no description available"
if not isinstance(tool_desc, str):
try:
tool_desc = str(tool_desc)
except Exception:
raise InvariantInputValidationError(
f"Tool description should be a string. Instead, got: {tool_desc} of type {type(tool_desc)}"
)
server = tool["metadata"].get("server", None) if "metadata" in tool else None
tool_obj = Tool(
name=name,
description=tool["description"],
description=tool_desc,
inputSchema=properties,
server=server,
)
parsed_data.append(tool_obj)
else:
Expand Down
16 changes: 14 additions & 2 deletions invariant/analyzer/runtime/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class Event(BaseModel):
default_factory=dict, description="Metadata associated with the event"
)

server: Optional[str] = None


class Function(BaseModel):
name: str
Expand Down Expand Up @@ -309,8 +311,18 @@ def image(*args) -> list[str]:
return result


# Type names a tool parameter may declare; extend as needed.
ToolParameterType = Literal["object", "array", "string", "number", "integer", "boolean"]


class ToolParameter(BaseModel):
type: Literal["object", "array", "string", "number", "boolean"] # extend as needed
type: ToolParameterType | list[ToolParameterType]
name: str
description: str
required: bool = False
Expand Down Expand Up @@ -351,7 +363,7 @@ class Tool(Event):
inputSchema: list[ToolParameter]

def __invariant_attribute__(self, name: str):
if name in ["name", "description", "inputSchema"]:
if name in ["name", "description", "inputSchema", "server"]:
return getattr(self, name)
raise InvariantAttributeError(
f"Attribute {name} not found in Tool. Available attributes are: name, description, inputSchema"
Expand Down
18 changes: 14 additions & 4 deletions invariant/analyzer/runtime/rule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import textwrap
from typing import Optional, Callable, TypeVar, ParamSpec, Awaitable
from typing import Awaitable, Callable, Optional, ParamSpec, TypeVar

import invariant.analyzer.language.ast as ast
from invariant.analyzer.language.linking import link
Expand All @@ -18,6 +18,7 @@
P = ParamSpec("P")
R = TypeVar("R")


class PolicyAction:
def __call__(self, input_dict):
raise NotImplementedError()
Expand Down Expand Up @@ -154,14 +155,20 @@ def from_raise_policy(cls, policy: ast.RaisePolicy, globals):

class InputEvaluationContext(EvaluationContext):
def __init__(
self, input: Input, rule_set: "RuleSet", policy_parameters, symbol_table: Optional[SymbolTable]
self,
input: Input,
rule_set: "RuleSet",
policy_parameters,
symbol_table: Optional[SymbolTable],
):
super().__init__(symbol_table=symbol_table)
self.input = input
self.rule_set = rule_set
self.policy_parameters = policy_parameters

async def acall_function(self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
async def acall_function(
self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
return await self.rule_set.acall_function(func, *args, **kwargs)

def has_flow(self, a, b):
Expand Down Expand Up @@ -211,7 +218,9 @@ def instance(self, cache: Optional[FunctionCache] = None):
def __del__(self):
self.function_cache.clear()

async def acall_function(self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
async def acall_function(
self, func: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
return await self.function_cache.acall(func, *args, **kwargs)

def log_apply(self, rule, model):
Expand Down Expand Up @@ -248,6 +257,7 @@ async def apply(self, input_data: Input, policy_parameters):
exceptions.extend(error)

self.input = None

return exceptions

def __str__(self):
Expand Down
Loading