Skip to content

Commit db671ba

Browse files
authored
Fix input/output message not redacted when guardrail_trace="enabled_full" (#1072)
* fix: detect guardrails with trace="enabled_full". Fix and simplify _find_detected_and_blocked_policy so that it works correctly even when the guardrail assessments contain both detected and non-detected filters (as with guardrail_trace="enabled_full") * test: add bedrock integration tests with different guardrail_trace levels * test: add xfail with guardrail_trace="disabled"
1 parent 3b00110 commit db671ba

File tree

3 files changed

+115
-15
lines changed

3 files changed

+115
-15
lines changed

src/strands/models/bedrock.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import os
1010
import warnings
11-
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast
11+
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast
1212

1313
import boto3
1414
from botocore.config import Config as BotocoreConfig
@@ -878,18 +878,12 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool:
878878
if input.get("action") == "BLOCKED" and input.get("detected") and isinstance(input.get("detected"), bool):
879879
return True
880880

881-
# Recursively check all values in the dictionary
882-
for value in input.values():
883-
if isinstance(value, dict):
884-
return self._find_detected_and_blocked_policy(value)
885-
# Handle case where value is a list of dictionaries
886-
elif isinstance(value, list):
887-
for item in value:
888-
return self._find_detected_and_blocked_policy(item)
889-
elif isinstance(input, list):
890-
# Handle case where input is a list of dictionaries
891-
for item in input:
892-
return self._find_detected_and_blocked_policy(item)
881+
# Otherwise, recursively check all values in the dictionary
882+
return self._find_detected_and_blocked_policy(input.values())
883+
884+
elif isinstance(input, (list, ValuesView)):
885+
# Handle case where input is a list or dict_values
886+
return any(self._find_detected_and_blocked_policy(item) for item in input)
893887
# Otherwise return False
894888
return False
895889

tests/strands/models/test_bedrock.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,99 @@ async def test_stream_stream_input_guardrails(
663663
bedrock_client.converse_stream.assert_called_once_with(**request)
664664

665665

666+
@pytest.mark.asyncio
667+
async def test_stream_stream_input_guardrails_full_trace(
668+
bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist
669+
):
670+
"""Test guardrails are correctly detected also with guardrail_trace="enabled_full".
671+
In that case bedrock returns all filters, including those not detected/blocked."""
672+
metadata_event = {
673+
"metadata": {
674+
"usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0},
675+
"metrics": {"latencyMs": 245},
676+
"trace": {
677+
"guardrail": {
678+
"inputAssessment": {
679+
"jrv9qlue4hag": {
680+
"contentPolicy": {
681+
"filters": [
682+
{
683+
"action": "NONE",
684+
"confidence": "NONE",
685+
"detected": False,
686+
"filterStrength": "HIGH",
687+
"type": "SEXUAL",
688+
},
689+
{
690+
"action": "BLOCKED",
691+
"confidence": "LOW",
692+
"detected": True,
693+
"filterStrength": "HIGH",
694+
"type": "VIOLENCE",
695+
},
696+
{
697+
"action": "NONE",
698+
"confidence": "NONE",
699+
"detected": False,
700+
"filterStrength": "HIGH",
701+
"type": "HATE",
702+
},
703+
{
704+
"action": "NONE",
705+
"confidence": "NONE",
706+
"detected": False,
707+
"filterStrength": "HIGH",
708+
"type": "INSULTS",
709+
},
710+
{
711+
"action": "NONE",
712+
"confidence": "NONE",
713+
"detected": False,
714+
"filterStrength": "HIGH",
715+
"type": "PROMPT_ATTACK",
716+
},
717+
{
718+
"action": "NONE",
719+
"confidence": "NONE",
720+
"detected": False,
721+
"filterStrength": "HIGH",
722+
"type": "MISCONDUCT",
723+
},
724+
]
725+
}
726+
}
727+
}
728+
}
729+
},
730+
}
731+
}
732+
bedrock_client.converse_stream.return_value = {"stream": [metadata_event]}
733+
734+
request = {
735+
"additionalModelRequestFields": additional_request_fields,
736+
"inferenceConfig": {},
737+
"modelId": model_id,
738+
"messages": messages,
739+
"system": [],
740+
"toolConfig": {
741+
"tools": [{"toolSpec": tool_spec}],
742+
"toolChoice": {"auto": {}},
743+
},
744+
}
745+
746+
model.update_config(additional_request_fields=additional_request_fields)
747+
response = model.stream(messages, [tool_spec])
748+
749+
tru_chunks = await alist(response)
750+
exp_chunks = [
751+
{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
752+
metadata_event,
753+
]
754+
755+
assert tru_chunks == exp_chunks
756+
bedrock_client.converse_stream.assert_called_once_with(**request)
757+
758+
666759
@pytest.mark.asyncio
667760
async def test_stream_stream_output_guardrails(
668761
bedrock_client, model, messages, tool_spec, model_id, additional_request_fields, alist

tests_integ/test_bedrock_guardrails.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,21 @@ def wait_for_guardrail_active(bedrock_client, guardrail_id, max_attempts=10, del
100100
raise RuntimeError("Guardrail did not become active.")
101101

102102

103-
def test_guardrail_input_intervention(boto_session, bedrock_guardrail):
103+
@pytest.mark.parametrize(
104+
"guardrail_trace",
105+
[
106+
pytest.param("disabled", marks=pytest.mark.xfail(reason='redact fails with trace="disabled"')),
107+
"enabled",
108+
"enabled_full",
109+
],
110+
)
111+
def test_guardrail_input_intervention(boto_session, bedrock_guardrail, guardrail_trace):
104112
bedrock_model = BedrockModel(
105113
guardrail_id=bedrock_guardrail,
106114
guardrail_version="DRAFT",
107115
boto_session=boto_session,
116+
guardrail_trace=guardrail_trace,
117+
guardrail_redact_input_message="Redacted.",
108118
)
109119

110120
agent = Agent(model=bedrock_model, system_prompt="You are a helpful assistant.", callback_handler=None)
@@ -116,6 +126,7 @@ def test_guardrail_input_intervention(boto_session, bedrock_guardrail):
116126
assert str(response1).strip() == BLOCKED_INPUT
117127
assert response2.stop_reason != "guardrail_intervened"
118128
assert str(response2).strip() != BLOCKED_INPUT
129+
assert agent.messages[0]["content"][0]["text"] == "Redacted."
119130

120131

121132
@pytest.mark.parametrize("processing_mode", ["sync", "async"])
@@ -159,13 +170,15 @@ def test_guardrail_output_intervention(boto_session, bedrock_guardrail, processi
159170
)
160171

161172

173+
@pytest.mark.parametrize("guardrail_trace", ["enabled", "enabled_full"])
162174
@pytest.mark.parametrize("processing_mode", ["sync", "async"])
163-
def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode):
175+
def test_guardrail_output_intervention_redact_output(bedrock_guardrail, processing_mode, guardrail_trace):
164176
REDACT_MESSAGE = "Redacted."
165177
bedrock_model = BedrockModel(
166178
guardrail_id=bedrock_guardrail,
167179
guardrail_version="DRAFT",
168180
guardrail_stream_processing_mode=processing_mode,
181+
guardrail_trace=guardrail_trace,
169182
guardrail_redact_output=True,
170183
guardrail_redact_output_message=REDACT_MESSAGE,
171184
region_name="us-east-1",

0 commit comments

Comments
 (0)