diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py
index 3505dad07..430b3dbee 100644
--- a/nemoguardrails/rails/llm/llmrails.py
+++ b/nemoguardrails/rails/llm/llmrails.py
@@ -1248,6 +1248,18 @@ async def generate_async(
             new_message["tool_calls"] = tool_calls
         return new_message
 
+    def _validate_streaming_with_output_rails(self) -> None:
+        if len(self.config.rails.output.flows) > 0 and (
+            not self.config.rails.output.streaming
+            or not self.config.rails.output.streaming.enabled
+        ):
+            raise ValueError(
+                "stream_async() cannot be used when output rails are configured but "
+                "rails.output.streaming.enabled is False. Either set "
+                "rails.output.streaming.enabled to True in your configuration, or use "
+                "generate_async() instead of stream_async()."
+            )
+
     def stream_async(
         self,
         prompt: Optional[str] = None,
@@ -1259,6 +1271,7 @@ def stream_async(
     ) -> AsyncIterator[str]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""
 
+        self._validate_streaming_with_output_rails()
         # if an external generator is provided, use it directly
         if generator:
             if (
diff --git a/tests/test_parallel_streaming_output_rails.py b/tests/test_parallel_streaming_output_rails.py
index 8b1166ef8..4359388fd 100644
--- a/tests/test_parallel_streaming_output_rails.py
+++ b/tests/test_parallel_streaming_output_rails.py
@@ -605,21 +605,24 @@ async def test_parallel_streaming_output_rails_performance_benefits():
 async def test_parallel_streaming_output_rails_default_config_behavior(
     parallel_output_rails_default_config,
 ):
-    """Tests parallel output rails with default streaming configuration"""
+    """Tests that stream_async raises an error with default config (no explicit streaming config)"""
 
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "This is a test message with default streaming config."',
-    ]
+    from nemoguardrails import LLMRails
 
-    chunks = await run_parallel_self_check_test(
-        parallel_output_rails_default_config, llm_completions
-    )
+    llmrails = LLMRails(parallel_output_rails_default_config)
 
-    response = "".join(chunks)
-    assert len(response) > 0
-    assert len(chunks) > 0
-    assert "test message" in response
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in llmrails.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}]
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
+    )
 
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 114bb7dd1..8fb1ac22a 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -474,6 +474,95 @@ def _calculate_number_of_actions(input_length, chunk_size, context_size):
     return math.ceil((input_length - context_size) / (chunk_size - context_size))
 
 
+@pytest.mark.asyncio
+async def test_streaming_with_output_rails_disabled_raises_error():
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": {"self check output"},
+                    "streaming": {
+                        "enabled": False,
+                    },
+                }
+            },
+            "streaming": True,
+            "prompts": [{"task": "self_check_output", "content": "a test template"}],
+        },
+        colang_content="""
+    define user express greeting
+      "hi"
+
+    define flow
+      user express greeting
+      bot tell joke
+    """,
+    )
+
+    chat = TestChat(
+        config,
+        llm_completions=[],
+        streaming=True,
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in chat.app.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}],
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
+    )
+
+
+@pytest.mark.asyncio
+async def test_streaming_with_output_rails_no_streaming_config_raises_error():
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": {"self check output"},
+                }
+            },
+            "streaming": True,
+            "prompts": [{"task": "self_check_output", "content": "a test template"}],
+        },
+        colang_content="""
+    define user express greeting
+      "hi"
+
+    define flow
+      user express greeting
+      bot tell joke
+    """,
+    )
+
+    chat = TestChat(
+        config,
+        llm_completions=[],
+        streaming=True,
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in chat.app.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}],
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
+    )
+
+
 @pytest.mark.asyncio
 async def test_streaming_error_handling():
     """Test that errors during streaming are properly formatted and returned."""
diff --git a/tests/test_streaming_output_rails.py b/tests/test_streaming_output_rails.py
index 501353c1b..8583a2fb9 100644
--- a/tests/test_streaming_output_rails.py
+++ b/tests/test_streaming_output_rails.py
@@ -87,18 +87,6 @@ def output_rails_streaming_config_default():
     )
 
 
-@pytest.mark.asyncio
-async def test_stream_async_streaming_disabled(output_rails_streaming_config_default):
-    """Tests if stream_async returns a StreamingHandler instance when streaming is disabled"""
-
-    llmrails = LLMRails(output_rails_streaming_config_default)
-
-    result = llmrails.stream_async(prompt="test")
-    assert isinstance(
-        result, StreamingHandler
-    ), "Expected StreamingHandler instance when streaming is disabled"
-
-
 @pytest.mark.asyncio
 async def test_stream_async_streaming_enabled(output_rails_streaming_config):
     """Tests if stream_async returns does not return StreamingHandler instance when streaming is enabled"""
@@ -175,33 +163,23 @@ async def test_streaming_output_rails_blocked_explicit(output_rails_streaming_co
 async def test_streaming_output_rails_blocked_default_config(
     output_rails_streaming_config_default,
 ):
-    """Tests if output rails streaming default config do not block content with BLOCK keyword"""
+    """Tests that stream_async raises an error with default config (output rails without explicit streaming config)"""
 
-    # text with a BLOCK keyword
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "This is a [BLOCK] joke that should be blocked."',
-    ]
+    llmrails = LLMRails(output_rails_streaming_config_default)
 
-    chunks = await run_self_check_test(
-        output_rails_streaming_config_default, llm_completions
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in llmrails.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}]
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
     )
 
-    expected_error = {
-        "error": {
-            "message": "Blocked by self check output rails.",
-            "type": "guardrails_violation",
-            "param": "self check output",
-            "code": "content_blocked",
-        }
-    }
-
-    error_chunks = [
-        json.loads(chunk) for chunk in chunks if chunk.startswith('{"error":')
-    ]
-    assert len(error_chunks) == 0
-    assert expected_error not in error_chunks
-
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
 
 
@@ -231,27 +209,6 @@ async def test_streaming_output_rails_blocked_at_start(output_rails_streaming_co
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
 
 
-@pytest.mark.asyncio
-async def test_streaming_output_rails_default_config_not_blocked_at_start(
-    output_rails_streaming_config_default,
-):
-    """Tests blocking with BLOCK at the very beginning of the response does not return abort sse"""
-
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "[BLOCK] This should be blocked immediately at the start."',
-    ]
-
-    chunks = await run_self_check_test(
-        output_rails_streaming_config_default, llm_completions
-    )
-
-    with pytest.raises(JSONDecodeError):
-        json.loads(chunks[0])
-
-    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
-
-
 async def simple_token_generator() -> AsyncIterator[str]:
     """Simple generator that yields tokens."""
     tokens = ["Hello", " ", "world", "!"]
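For reviewers, a minimal sketch (not part of the patch) of the caller-facing contract this change introduces. It reuses only the config shape from the tests above; since the new _validate_streaming_with_output_rails() check runs synchronously inside stream_async(), the ValueError surfaces as soon as the stream is requested, before any tokens are produced.

import asyncio

from nemoguardrails import LLMRails, RailsConfig

config = RailsConfig.from_content(
    config={
        # "models": [] mirrors the test configs; a real app would configure a main model.
        "models": [],
        # Output rails are configured, but rails.output.streaming is omitted,
        # so stream_async() now fails fast instead of streaming unchecked text.
        "rails": {"output": {"flows": {"self check output"}}},
        "streaming": True,
        "prompts": [{"task": "self_check_output", "content": "a test template"}],
    },
    colang_content="""
    define user express greeting
      "hi"

    define flow
      user express greeting
      bot tell joke
    """,
)


async def main():
    rails = LLMRails(config)
    try:
        async for chunk in rails.stream_async(
            messages=[{"role": "user", "content": "Hi!"}]
        ):
            print(chunk, end="")
    except ValueError as exc:
        # The message points at the two fixes: set rails.output.streaming.enabled
        # to True, or call generate_async() instead.
        print(exc)


asyncio.run(main())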