
Commit e29dfd6

fix(streaming)!: raise error when stream_async used with disabled output rails streaming (#1470)
* fix(streaming): raise error when stream_async used with disabled output rails streaming

  When output rails are configured but output.streaming.enabled is False (or not set), calling stream_async() would result in undefined behavior or hangs due to the conflict between streaming expectations and blocking output rail processing.

  This change adds explicit validation in stream_async() to detect this misconfiguration and raise a clear ValueError with actionable guidance:

  - Set rails.output.streaming.enabled = True to use streaming with output rails
  - Use generate_async() instead for non-streaming with output rails

  Updated affected tests to expect and validate the new error behavior instead of relying on the previous buggy behavior.
1 parent 20d86ab commit e29dfd6
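
What the change means for callers, as a minimal sketch: the config below mirrors the shape used in the new tests (output rails configured, no rails.output.streaming section). The empty models list is only workable here because the ValueError is raised before any generation starts.

import asyncio

from nemoguardrails import LLMRails, RailsConfig

# Output rails are configured, but rails.output.streaming is absent/disabled.
config = RailsConfig.from_content(
    config={
        "models": [],
        "rails": {"output": {"flows": {"self check output"}}},
        "streaming": True,
        "prompts": [{"task": "self_check_output", "content": "a test template"}],
    },
)
rails = LLMRails(config)

async def main():
    try:
        async for chunk in rails.stream_async(
            messages=[{"role": "user", "content": "Hi!"}]
        ):
            print(chunk, end="")
    except ValueError as exc:
        # As of this commit, the misconfiguration fails fast with a clear
        # message instead of hanging.
        print(f"caught: {exc}")

asyncio.run(main())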

File tree

4 files changed: +130 −68 lines

- nemoguardrails/rails/llm/llmrails.py
- tests/test_parallel_streaming_output_rails.py
- tests/test_streaming.py
- tests/test_streaming_output_rails.py

nemoguardrails/rails/llm/llmrails.py

Lines changed: 13 additions & 0 deletions
@@ -1240,6 +1240,18 @@ async def generate_async(
             new_message["tool_calls"] = tool_calls
         return new_message

+    def _validate_streaming_with_output_rails(self) -> None:
+        if len(self.config.rails.output.flows) > 0 and (
+            not self.config.rails.output.streaming
+            or not self.config.rails.output.streaming.enabled
+        ):
+            raise ValueError(
+                "stream_async() cannot be used when output rails are configured but "
+                "rails.output.streaming.enabled is False. Either set "
+                "rails.output.streaming.enabled to True in your configuration, or use "
+                "generate_async() instead of stream_async()."
+            )
+
     def stream_async(
         self,
         prompt: Optional[str] = None,
@@ -1251,6 +1263,7 @@ def stream_async(
     ) -> AsyncIterator[str]:
         """Simplified interface for getting directly the streamed tokens from the LLM."""

+        self._validate_streaming_with_output_rails()
        # if an external generator is provided, use it directly
        if generator:
            if (
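
The error message names two remediations. A hedged sketch of the first — enabling output rails streaming so stream_async() passes the new check. The chunk_size and context_size values are illustrative assumptions (they are the knobs the streaming tests' _calculate_number_of_actions helper reasons about), not defaults introduced by this commit.

from nemoguardrails import RailsConfig

config = RailsConfig.from_content(
    config={
        "models": [],
        "rails": {
            "output": {
                "flows": {"self check output"},
                "streaming": {
                    "enabled": True,  # satisfies _validate_streaming_with_output_rails()
                    "chunk_size": 200,  # illustrative value
                    "context_size": 50,  # illustrative value
                },
            }
        },
        "streaming": True,
        "prompts": [{"task": "self_check_output", "content": "a test template"}],
    },
)

# The second remediation skips streaming and lets output rails run on the
# full response; given an LLMRails instance `rails`:
#     new_message = await rails.generate_async(
#         messages=[{"role": "user", "content": "Hi!"}]
#     )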

tests/test_parallel_streaming_output_rails.py

Lines changed: 15 additions & 12 deletions
@@ -605,21 +605,24 @@ async def test_parallel_streaming_output_rails_performance_benefits():
 async def test_parallel_streaming_output_rails_default_config_behavior(
     parallel_output_rails_default_config,
 ):
-    """Tests parallel output rails with default streaming configuration"""
+    """Tests that stream_async raises an error with default config (no explicit streaming config)"""

-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "This is a test message with default streaming config."',
-    ]
+    from nemoguardrails import LLMRails

-    chunks = await run_parallel_self_check_test(
-        parallel_output_rails_default_config, llm_completions
-    )
+    llmrails = LLMRails(parallel_output_rails_default_config)

-    response = "".join(chunks)
-    assert len(response) > 0
-    assert len(chunks) > 0
-    assert "test message" in response
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in llmrails.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}]
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
+    )

     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})

tests/test_streaming.py

Lines changed: 89 additions & 0 deletions
@@ -474,6 +474,95 @@ def _calculate_number_of_actions(input_length, chunk_size, context_size):
     return math.ceil((input_length - context_size) / (chunk_size - context_size))


+@pytest.mark.asyncio
+async def test_streaming_with_output_rails_disabled_raises_error():
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": {"self check output"},
+                    "streaming": {
+                        "enabled": False,
+                    },
+                }
+            },
+            "streaming": True,
+            "prompts": [{"task": "self_check_output", "content": "a test template"}],
+        },
+        colang_content="""
+        define user express greeting
+          "hi"
+
+        define flow
+          user express greeting
+          bot tell joke
+        """,
+    )
+
+    chat = TestChat(
+        config,
+        llm_completions=[],
+        streaming=True,
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in chat.app.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}],
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
+    )
+
+
+@pytest.mark.asyncio
+async def test_streaming_with_output_rails_no_streaming_config_raises_error():
+    config = RailsConfig.from_content(
+        config={
+            "models": [],
+            "rails": {
+                "output": {
+                    "flows": {"self check output"},
+                }
+            },
+            "streaming": True,
+            "prompts": [{"task": "self_check_output", "content": "a test template"}],
+        },
+        colang_content="""
+        define user express greeting
+          "hi"
+
+        define flow
+          user express greeting
+          bot tell joke
+        """,
+    )
+
+    chat = TestChat(
+        config,
+        llm_completions=[],
+        streaming=True,
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in chat.app.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}],
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
+    )
+
+
 @pytest.mark.asyncio
 async def test_streaming_error_handling():
     """Test that errors during streaming are properly formatted and returned."""

tests/test_streaming_output_rails.py

Lines changed: 13 additions & 56 deletions
@@ -87,18 +87,6 @@ def output_rails_streaming_config_default():
     )


-@pytest.mark.asyncio
-async def test_stream_async_streaming_disabled(output_rails_streaming_config_default):
-    """Tests if stream_async returns a StreamingHandler instance when streaming is disabled"""
-
-    llmrails = LLMRails(output_rails_streaming_config_default)
-
-    result = llmrails.stream_async(prompt="test")
-    assert isinstance(
-        result, StreamingHandler
-    ), "Expected StreamingHandler instance when streaming is disabled"
-
-
 @pytest.mark.asyncio
 async def test_stream_async_streaming_enabled(output_rails_streaming_config):
     """Tests that stream_async does not return a StreamingHandler instance when streaming is enabled"""
@@ -175,33 +163,23 @@ async def test_streaming_output_rails_blocked_explicit(output_rails_streaming_co
 async def test_streaming_output_rails_blocked_default_config(
     output_rails_streaming_config_default,
 ):
-    """Tests if output rails streaming default config do not block content with BLOCK keyword"""
+    """Tests that stream_async raises an error with default config (output rails without explicit streaming config)"""

-    # text with a BLOCK keyword
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "This is a [BLOCK] joke that should be blocked."',
-    ]
+    llmrails = LLMRails(output_rails_streaming_config_default)

-    chunks = await run_self_check_test(
-        output_rails_streaming_config_default, llm_completions
+    with pytest.raises(ValueError) as exc_info:
+        async for chunk in llmrails.stream_async(
+            messages=[{"role": "user", "content": "Hi!"}]
+        ):
+            pass
+
+    assert str(exc_info.value) == (
+        "stream_async() cannot be used when output rails are configured but "
+        "rails.output.streaming.enabled is False. Either set "
+        "rails.output.streaming.enabled to True in your configuration, or use "
+        "generate_async() instead of stream_async()."
     )

-    expected_error = {
-        "error": {
-            "message": "Blocked by self check output rails.",
-            "type": "guardrails_violation",
-            "param": "self check output",
-            "code": "content_blocked",
-        }
-    }
-
-    error_chunks = [
-        json.loads(chunk) for chunk in chunks if chunk.startswith('{"error":')
-    ]
-    assert len(error_chunks) == 0
-    assert expected_error not in error_chunks
-
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})

@@ -231,27 +209,6 @@ async def test_streaming_output_rails_blocked_at_start(output_rails_streaming_co
     await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})


-@pytest.mark.asyncio
-async def test_streaming_output_rails_default_config_not_blocked_at_start(
-    output_rails_streaming_config_default,
-):
-    """Tests blocking with BLOCK at the very beginning of the response does not return abort sse"""
-
-    llm_completions = [
-        ' express greeting\nbot express greeting\n "Hi, how are you doing?"',
-        ' "[BLOCK] This should be blocked immediately at the start."',
-    ]
-
-    chunks = await run_self_check_test(
-        output_rails_streaming_config_default, llm_completions
-    )
-
-    with pytest.raises(JSONDecodeError):
-        json.loads(chunks[0])
-
-    await asyncio.gather(*asyncio.all_tasks() - {asyncio.current_task()})
-
-
 async def simple_token_generator() -> AsyncIterator[str]:
     """Simple generator that yields tokens."""
     tokens = ["Hello", " ", "world", "!"]
