diff --git a/invariant/analyzer/runtime/function_cache.py b/invariant/analyzer/runtime/function_cache.py index e3f1f54..11142e2 100644 --- a/invariant/analyzer/runtime/function_cache.py +++ b/invariant/analyzer/runtime/function_cache.py @@ -74,5 +74,4 @@ async def call_either_way( if inspect.iscoroutinefunction(func): return await func(*args, **kwargs) else: - print([func, args, kwargs], flush=True) return func(*args, **kwargs) # type: ignore diff --git a/invariant/analyzer/runtime/nodes.py b/invariant/analyzer/runtime/nodes.py index bafe9a5..420e20b 100644 --- a/invariant/analyzer/runtime/nodes.py +++ b/invariant/analyzer/runtime/nodes.py @@ -135,9 +135,11 @@ class Contents(RootModel[list[Chunk]]): def __contains__(self, item: object) -> bool: for chunk in self.root: if isinstance(chunk, TextChunk): - return item in chunk.text + if item in chunk.text: + return True elif isinstance(chunk, Image): - return item in chunk.image_url["url"] + if item in chunk.image_url["url"]: + return True return False def __invariant_attribute__(self, name: str): @@ -324,7 +326,16 @@ class ToolParameter(BaseModel): enum: Optional[List[str]] = None def __invariant_attribute__(self, name: str): - if name in ["type", "name", "description", "required", "properties", "additionalProperties", "items", "enum"]: + if name in [ + "type", + "name", + "description", + "required", + "properties", + "additionalProperties", + "items", + "enum", + ]: return getattr(self, name) raise InvariantAttributeError( f"Attribute {name} not found in ToolParameter. Available attributes are: type, name, description, required, properties, additionalProperties, items, enum" diff --git a/invariant/tests/analyzer/test_chunked.py b/invariant/tests/analyzer/test_chunked.py index 7a853aa..33a498a 100644 --- a/invariant/tests/analyzer/test_chunked.py +++ b/invariant/tests/analyzer/test_chunked.py @@ -1,6 +1,6 @@ import unittest -from invariant.analyzer import Monitor +from invariant.analyzer import Monitor, Policy from invariant.analyzer.traces import chunked @@ -86,3 +86,56 @@ def test_simple_but_multiple_text_chunks(self): input.extend(pending_input) assert len(errors) == 0, "Expected no errors, but got: " + str(errors) + + def test_contains_in_text_chunk(self): + # tests that 'abc' in message.content works both when 'abc' is in chunk 0 or chunk 1 + policy = Policy.from_string( + """ + +raise "pattern found" if: + (msg: Message) + msg.role == "assistant" + "abc" in msg.content +""" + ) + + # in second chunk + input = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello, "}, + {"type": "text", "text": "aa abc aa"}, + ], + } + ] + + result = policy.analyze(input, []) + assert len(result.errors) == 1, "Expected one error, but got: " + str(result) + + # in no chunk + input = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello, "}, + {"type": "text", "text": "adc"}, + ], + } + ] + + result = policy.analyze(input, []) + assert len(result.errors) == 0, "Expected no errors, but got: " + str(result) + + # in first chunk + input = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "aa abc aa"}, + {"type": "text", "text": "Hello, "}, + ], + } + ] + result = policy.analyze(input, []) + assert len(result.errors) == 1, "Expected one error, but got: " + str(result) diff --git a/uv.lock b/uv.lock index 08cbc11..c75005a 100644 --- a/uv.lock +++ b/uv.lock @@ -842,7 +842,7 @@ wheels = [ [[package]] name = "invariant-ai" -version = "0.3" +version = "0.3.1" source = { virtual = "." } dependencies = [ { name = "aiohttp" },