diff --git a/models/src/agent_control_models/__init__.py b/models/src/agent_control_models/__init__.py index cdec174d..32e9b997 100644 --- a/models/src/agent_control_models/__init__.py +++ b/models/src/agent_control_models/__init__.py @@ -12,6 +12,8 @@ expand_action_filter, normalize_action, normalize_action_list, + validate_action, + validate_action_list, ) from .agent import ( BUILTIN_STEP_TYPES, @@ -144,6 +146,8 @@ "UnrenderedTemplateControl", "normalize_action", "normalize_action_list", + "validate_action", + "validate_action_list", "expand_action_filter", # Error models "ProblemDetail", diff --git a/models/src/agent_control_models/actions.py b/models/src/agent_control_models/actions.py index a890c6f9..972f5008 100644 --- a/models/src/agent_control_models/actions.py +++ b/models/src/agent_control_models/actions.py @@ -7,6 +7,7 @@ type ActionDecision = Literal["deny", "steer", "observe"] +_CANONICAL_ACTIONS = frozenset({"deny", "steer", "observe"}) _OBSERVE_ACTION_ALIASES = frozenset({"allow", "observe", "warn", "log"}) _ACTION_QUERY_EXPANSION: dict[ActionDecision, tuple[str, ...]] = { "deny": ("deny",), @@ -15,15 +16,44 @@ } +def validate_action(action: str) -> ActionDecision: + """Validate that *action* is one of the canonical action values. + + Use this on public API boundaries (control create/update, query filters) + where legacy values should be rejected. + """ + if action in _CANONICAL_ACTIONS: + return cast(ActionDecision, action) + raise ValueError( + f"Invalid action {action!r}. Must be one of: deny, steer, observe." + ) + + +def validate_action_list(actions: Sequence[str]) -> list[ActionDecision]: + """Validate a list of actions, preserving order and removing duplicates.""" + validated: list[ActionDecision] = [] + seen: set[ActionDecision] = set() + for action in actions: + canonical = validate_action(action) + if canonical in seen: + continue + seen.add(canonical) + validated.append(canonical) + return validated + + def normalize_action(action: str) -> ActionDecision: - """Normalize a public or legacy action name to the canonical action.""" + """Normalize a stored or legacy action name to the canonical action. + + Use this on internal read paths (deserializing DB rows, server responses) + where historical data may contain legacy values. + """ if action in _OBSERVE_ACTION_ALIASES: return "observe" if action in ("deny", "steer"): return cast(ActionDecision, action) raise ValueError( - "Invalid action. Expected one of: deny, steer, observe " - "(legacy aliases allow/warn/log are also accepted temporarily)." + f"Invalid action {action!r}. Expected one of: deny, steer, observe." ) diff --git a/models/src/agent_control_models/controls.py b/models/src/agent_control_models/controls.py index b9d9de53..b34e7056 100644 --- a/models/src/agent_control_models/controls.py +++ b/models/src/agent_control_models/controls.py @@ -11,7 +11,7 @@ import re2 from pydantic import ConfigDict, Field, ValidationInfo, field_validator, model_validator -from .actions import ActionDecision, normalize_action +from .actions import ActionDecision, normalize_action, validate_action from .agent import JSONValue from .base import BaseModel @@ -520,8 +520,8 @@ class ControlAction(BaseModel): @field_validator("decision", mode="before") @classmethod - def normalize_decision(cls, value: str) -> ActionDecision: - return normalize_action(value) + def validate_decision(cls, value: str) -> ActionDecision: + return validate_action(value) MAX_CONDITION_DEPTH = 6 diff --git a/models/src/agent_control_models/observability.py b/models/src/agent_control_models/observability.py index a78e4b1e..dbd11fac 100644 --- a/models/src/agent_control_models/observability.py +++ b/models/src/agent_control_models/observability.py @@ -14,7 +14,11 @@ from pydantic import Field, field_validator -from .actions import ActionDecision, normalize_action, normalize_action_list +from .actions import ( + ActionDecision, + normalize_action, + validate_action_list, +) from .agent import AGENT_NAME_MIN_LENGTH, AGENT_NAME_PATTERN, normalize_agent_name from .base import BaseModel @@ -343,12 +347,12 @@ def validate_and_normalize_agent_name( @field_validator("actions", mode="before") @classmethod - def normalize_actions_filter( + def validate_actions_filter( cls, value: list[str] | None ) -> list[ActionDecision] | None: if value is None: return None - return normalize_action_list(value) + return validate_action_list(value) class EventQueryResponse(BaseModel): diff --git a/models/tests/test_actions.py b/models/tests/test_actions.py index c9a8c0cf..2b0c5775 100644 --- a/models/tests/test_actions.py +++ b/models/tests/test_actions.py @@ -1,4 +1,4 @@ -"""Tests for shared control-action compatibility behavior.""" +"""Tests for shared control-action types, validation, and normalization.""" from __future__ import annotations @@ -11,50 +11,240 @@ EvaluatorResult, expand_action_filter, ) +from agent_control_models.actions import normalize_action, validate_action from pydantic import ValidationError -def test_event_query_actions_normalize_and_expand_for_legacy_observability() -> None: - # Given: a query that mixes canonical and legacy advisory action names - query = EventQueryRequest( - actions=["warn", "observe", "deny", "log", "deny", "steer", "allow", "steer"] - ) +# --------------------------------------------------------------------------- +# validate_action (strict, for API boundaries) +# --------------------------------------------------------------------------- - # When: expanding the normalized public action filter for stored event rows - expanded = expand_action_filter(query.actions or []) - # Then: the public filter is canonicalized, deduped, and expanded for legacy rows - assert query.actions == ["observe", "deny", "steer"] - assert expanded == ["observe", "allow", "warn", "log", "deny", "steer"] +class TestValidateAction: + """Tests for the strict validate_action used on public API boundaries.""" + @pytest.mark.parametrize("action", ["deny", "steer", "observe"]) + def test_accepts_canonical_actions(self, action: str) -> None: + # Given: a canonical action name + # When: validating the action + result = validate_action(action) -def test_invalid_action_is_rejected_across_public_model_boundaries() -> None: - # Given: the same invalid action at each public model boundary - invalid_action = "block" - invalid_builders = [ - lambda: ControlAction.model_validate({"decision": invalid_action}), - lambda: ControlMatch( - control_id=123, - control_name="pii-check", - action=invalid_action, + # Then: the same canonical value is returned + assert result == action + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_rejects_legacy_actions(self, legacy: str) -> None: + # Given: a legacy action name that is no longer accepted at API boundaries + # When / Then: validation raises ValueError + with pytest.raises(ValueError, match="Invalid action"): + validate_action(legacy) + + def test_rejects_unknown_action(self) -> None: + # Given: a completely unknown action name + # When / Then: validation raises ValueError + with pytest.raises(ValueError, match="Invalid action"): + validate_action("block") + + +# --------------------------------------------------------------------------- +# normalize_action (lenient, for internal read paths) +# --------------------------------------------------------------------------- + + +class TestNormalizeAction: + """Tests for the lenient normalize_action used on read paths.""" + + @pytest.mark.parametrize("action", ["deny", "steer", "observe"]) + def test_passes_canonical_actions(self, action: str) -> None: + # Given: a canonical action name + # When: normalizing + result = normalize_action(action) + + # Then: the same value is returned unchanged + assert result == action + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_normalizes_legacy_to_observe(self, legacy: str) -> None: + # Given: a legacy advisory action stored in a historical DB row + # When: normalizing on the read path + result = normalize_action(legacy) + + # Then: it maps to the canonical "observe" action + assert result == "observe" + + def test_rejects_unknown_action(self) -> None: + # Given: a completely unknown action name + # When / Then: normalization raises ValueError even on the lenient path + with pytest.raises(ValueError, match="Invalid action"): + normalize_action("block") + + +# --------------------------------------------------------------------------- +# ControlAction (API input boundary — strict) +# --------------------------------------------------------------------------- + + +class TestControlActionValidation: + """ControlAction.decision uses strict validation (rejects legacy values).""" + + @pytest.mark.parametrize("action", ["deny", "steer", "observe"]) + def test_accepts_canonical_actions(self, action: str) -> None: + # Given: a control action payload with a canonical decision + # When: validating via Pydantic + ca = ControlAction.model_validate({"decision": action}) + + # Then: the decision is accepted as-is + assert ca.decision == action + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_rejects_legacy_actions(self, legacy: str) -> None: + # Given: a control action payload using a legacy decision value + # When / Then: Pydantic validation rejects it at the API boundary + with pytest.raises(ValidationError, match="Invalid action"): + ControlAction.model_validate({"decision": legacy}) + + def test_rejects_unknown_action(self) -> None: + # Given: a control action payload with an unknown decision + # When / Then: Pydantic validation rejects it + with pytest.raises(ValidationError, match="Invalid action"): + ControlAction.model_validate({"decision": "block"}) + + +# --------------------------------------------------------------------------- +# EventQueryRequest.actions (API input boundary — strict) +# --------------------------------------------------------------------------- + + +class TestEventQueryRequestValidation: + """EventQueryRequest.actions uses strict validation.""" + + def test_accepts_canonical_actions(self) -> None: + # Given: a query filter with all three canonical action values + # When: constructing the query request + query = EventQueryRequest(actions=["deny", "steer", "observe"]) + + # Then: all actions are accepted + assert query.actions == ["deny", "steer", "observe"] + + def test_deduplicates_actions(self) -> None: + # Given: a query filter with duplicate action values + # When: constructing the query request + query = EventQueryRequest(actions=["deny", "deny", "observe"]) + + # Then: duplicates are removed while preserving order + assert query.actions == ["deny", "observe"] + + @pytest.mark.parametrize("legacy", ["allow", "warn", "log"]) + def test_rejects_legacy_actions(self, legacy: str) -> None: + # Given: a query filter using a legacy action value + # When / Then: Pydantic validation rejects it at the API boundary + with pytest.raises(ValidationError, match="Invalid action"): + EventQueryRequest(actions=[legacy]) + + def test_rejects_unknown_action(self) -> None: + # Given: a query filter with an unknown action + # When / Then: Pydantic validation rejects it + with pytest.raises(ValidationError, match="Invalid action"): + EventQueryRequest(actions=["block"]) + + +# --------------------------------------------------------------------------- +# ControlMatch / ControlExecutionEvent (read path — lenient normalization) +# --------------------------------------------------------------------------- + + +class TestReadPathNormalization: + """Internal read-path models normalize legacy values from DB rows.""" + + @pytest.mark.parametrize("legacy,expected", [ + ("allow", "observe"), + ("warn", "observe"), + ("log", "observe"), + ("observe", "observe"), + ("deny", "deny"), + ("steer", "steer"), + ]) + def test_control_match_normalizes_legacy(self, legacy: str, expected: str) -> None: + # Given: a ControlMatch deserialized from a DB row with a legacy action + # When: constructing the model + match = ControlMatch( + control_id=1, + control_name="test", + action=legacy, result=EvaluatorResult(matched=True, confidence=0.9), - ), - lambda: ControlExecutionEvent( - trace_id="trace-123", - span_id="span-123", + ) + + # Then: the action is normalized to the canonical value + assert match.action == expected + + @pytest.mark.parametrize("legacy,expected", [ + ("allow", "observe"), + ("warn", "observe"), + ("log", "observe"), + ("observe", "observe"), + ("deny", "deny"), + ("steer", "steer"), + ]) + def test_control_execution_event_normalizes_legacy( + self, legacy: str, expected: str + ) -> None: + # Given: a ControlExecutionEvent deserialized from a historical event row + # When: constructing the model + event = ControlExecutionEvent( + trace_id="4bf92f3577b34da6a3ce929d0e0e4736", + span_id="00f067aa0ba902b7", agent_name="test-agent", - control_id=123, - control_name="pii-check", + control_id=1, + control_name="test", check_stage="pre", applies_to="llm_call", - action=invalid_action, + action=legacy, matched=True, confidence=0.9, - ), - lambda: EventQueryRequest(actions=[invalid_action]), - ] + ) + + # Then: the action is normalized to the canonical value + assert event.action == expected - for build_invalid_model in invalid_builders: - # When / Then: validation fails before the invalid action can enter the system + def test_control_match_rejects_unknown(self) -> None: + # Given: a ControlMatch with a completely unknown action + # When / Then: validation rejects it even on the lenient read path with pytest.raises(ValidationError, match="Invalid action"): - build_invalid_model() + ControlMatch( + control_id=1, + control_name="test", + action="block", + result=EvaluatorResult(matched=True, confidence=0.9), + ) + + +# --------------------------------------------------------------------------- +# expand_action_filter (internal query expansion) +# --------------------------------------------------------------------------- + + +class TestExpandActionFilter: + """expand_action_filter expands canonical actions for SQL queries against historical data.""" + + def test_observe_expands_to_include_legacy(self) -> None: + # Given: a canonical "observe" filter + # When: expanding for SQL WHERE clause against historical events + expanded = expand_action_filter(["observe"]) + + # Then: it includes all legacy advisory action values stored in old rows + assert expanded == ["observe", "allow", "warn", "log"] + + def test_deny_and_steer_do_not_expand(self) -> None: + # Given: deny and steer filters (no legacy aliases) + # When: expanding + # Then: they map only to themselves + assert expand_action_filter(["deny"]) == ["deny"] + assert expand_action_filter(["steer"]) == ["steer"] + + def test_full_expansion(self) -> None: + # Given: all three canonical actions + # When: expanding + expanded = expand_action_filter(["deny", "steer", "observe"]) + + # Then: deny and steer are unchanged, observe expands to include legacy + assert expanded == ["deny", "steer", "observe", "allow", "warn", "log"] diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index 1e7a1025..6d3c84f5 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -136,7 +136,7 @@ def _make_control(self, id, name, condition): control=ControlDefinition( execution="sdk", condition=condition, - action={"decision": "allow"}, + action={"decision": "observe"}, ), ) @@ -150,7 +150,7 @@ def _make_request(self, step_type="llm"): stage="pre", ) - def _make_match(self, control_id, control_name="ctrl", action="allow", matched=True): + def _make_match(self, control_id, control_name="ctrl", action="observe", matched=True): from agent_control_models import ControlMatch, EvaluatorResult return ControlMatch( @@ -264,7 +264,7 @@ def test_preserves_error_message_parity_by_result_category(self): ControlMatch( control_id=1, control_name="ctrl-1", - action="allow", + action="observe", result=EvaluatorResult( matched=True, confidence=0.9, @@ -276,7 +276,7 @@ def test_preserves_error_message_parity_by_result_category(self): ControlMatch( control_id=1, control_name="ctrl-1", - action="allow", + action="observe", result=EvaluatorResult(matched=False, confidence=0.2, error="eval-error"), ) ], @@ -284,7 +284,7 @@ def test_preserves_error_message_parity_by_result_category(self): ControlMatch( control_id=1, control_name="ctrl-1", - action="allow", + action="observe", result=EvaluatorResult(matched=False, confidence=0.1, error="ignored-error"), ) ], @@ -547,15 +547,15 @@ async def test_skips_local_event_reconstruction_when_observability_disabled(self controls = [{ "id": 1, "name": "local-ctrl", - "control": { - "condition": { - "evaluator": {"name": "regex", "config": {"pattern": "test"}}, - "selector": {"path": "input"}, + "control": { + "condition": { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + "action": {"decision": "observe"}, + "execution": "sdk", }, - "action": {"decision": "allow"}, - "execution": "sdk", - }, - }] + }] mock_response = EvaluationResponse(is_safe=True, confidence=1.0) mock_engine = MagicMock() @@ -633,7 +633,7 @@ async def test_merged_event_mode_enqueues_reconstructed_local_and_server_events_ ControlMatch( control_id=1, control_name="local-ctrl", - action="allow", + action="observe", result=EvaluatorResult(matched=False, confidence=0.8), ) ], @@ -645,7 +645,7 @@ async def test_merged_event_mode_enqueues_reconstructed_local_and_server_events_ { "control_id": 2, "control_name": "server-ctrl", - "action": "allow", + "action": "observe", "control_execution_id": "ce-server", "result": {"matched": False, "confidence": 0.4}, } @@ -663,7 +663,7 @@ async def test_merged_event_mode_enqueues_reconstructed_local_and_server_events_ "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "allow"}, + "action": {"decision": "observe"}, "execution": "sdk", }, }, @@ -675,7 +675,7 @@ async def test_merged_event_mode_enqueues_reconstructed_local_and_server_events_ "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "allow"}, + "action": {"decision": "observe"}, "execution": "server", }, }, @@ -724,7 +724,7 @@ async def test_merged_event_mode_enqueues_local_events_before_reraising_server_f ControlMatch( control_id=1, control_name="local-ctrl", - action="allow", + action="observe", result=EvaluatorResult(matched=False, confidence=0.8), ) ], @@ -739,7 +739,7 @@ async def test_merged_event_mode_enqueues_local_events_before_reraising_server_f "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "allow"}, + "action": {"decision": "observe"}, "execution": "sdk", }, }, @@ -751,7 +751,7 @@ async def test_merged_event_mode_enqueues_local_events_before_reraising_server_f "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "allow"}, + "action": {"decision": "observe"}, "execution": "server", }, }, @@ -796,7 +796,7 @@ async def test_merged_event_mode_enqueues_only_local_events_when_no_server_contr ControlMatch( control_id=1, control_name="local-ctrl", - action="allow", + action="observe", result=EvaluatorResult(matched=True, confidence=0.8), ) ], @@ -810,7 +810,7 @@ async def test_merged_event_mode_enqueues_only_local_events_when_no_server_contr "evaluator": {"name": "regex", "config": {"pattern": "test"}}, "selector": {"path": "input"}, }, - "action": {"decision": "allow"}, + "action": {"decision": "observe"}, "execution": "sdk", }, } diff --git a/server/tests/test_control_templates.py b/server/tests/test_control_templates.py index 0f8fd1c5..ab669c55 100644 --- a/server/tests/test_control_templates.py +++ b/server/tests/test_control_templates.py @@ -638,17 +638,17 @@ def test_template_backed_control_supports_direct_agent_attachment(client: TestCl assert body["matches"][0]["control_name"] == control_name -def test_template_backed_warn_control_evaluates_as_safe_with_observe_match( +def test_template_backed_observe_control_evaluates_as_safe_with_observe_match( client: TestClient, ) -> None: - # Given: a template-backed control whose rendered action uses the legacy warn alias + # Given: a template-backed control whose rendered action is observe payload = _template_payload() payload["template"] = deepcopy(payload["template"]) - payload["template"]["definition_template"]["action"]["decision"] = "warn" # type: ignore[index] + payload["template"]["definition_template"]["action"]["decision"] = "observe" # type: ignore[index] control_id, control_name = _create_template_control_with_name( client, payload, - name_prefix="warn-template-control", + name_prefix="observe-template-control", ) agent_name = _assign_control_to_agent(client, control_id) @@ -725,7 +725,7 @@ def test_template_backed_control_preserves_falsey_values_and_uses_them_in_behavi def test_mixed_raw_and_template_backed_controls_obey_deny_precedence( client: TestClient, ) -> None: - # Given: an agent with both a template-backed deny control and a raw warn control + # Given: an agent with both a template-backed deny control and a raw observe control template_control_id, template_control_name = _create_template_control_with_name(client) agent_name = _assign_control_to_agent(client, template_control_id) @@ -733,19 +733,19 @@ def test_mixed_raw_and_template_backed_controls_obey_deny_precedence( assert policy_response.status_code == 200, policy_response.text policy_id = policy_response.json()["policy_id"] - raw_warn_name = f"raw-warn-{uuid.uuid4()}" - raw_warn_response = client.put( + raw_observe_name = f"raw-observe-{uuid.uuid4()}" + raw_observe_response = client.put( "/api/v1/controls", json={ - "name": raw_warn_name, - "data": _raw_control_payload("hello", action="warn"), + "name": raw_observe_name, + "data": _raw_control_payload("hello", action="observe"), }, ) - assert raw_warn_response.status_code == 200, raw_warn_response.text - raw_warn_control_id = raw_warn_response.json()["control_id"] + assert raw_observe_response.status_code == 200, raw_observe_response.text + raw_observe_control_id = raw_observe_response.json()["control_id"] add_control_response = client.post( - f"/api/v1/policies/{policy_id}/controls/{raw_warn_control_id}" + f"/api/v1/policies/{policy_id}/controls/{raw_observe_control_id}" ) assert add_control_response.status_code == 200, add_control_response.text @@ -764,7 +764,7 @@ def test_mixed_raw_and_template_backed_controls_obey_deny_precedence( assert len(body["matches"]) == 2 names = {match["control_name"] for match in body["matches"]} actions = {match["action"] for match in body["matches"]} - assert names == {template_control_name, raw_warn_name} + assert names == {template_control_name, raw_observe_name} assert actions == {"deny", "observe"} diff --git a/server/tests/test_observability_models.py b/server/tests/test_observability_models.py index deb30fa1..4a7d9252 100644 --- a/server/tests/test_observability_models.py +++ b/server/tests/test_observability_models.py @@ -307,10 +307,20 @@ def test_filter_by_trace_id(self): assert query.trace_id == "4bf92f3577b34da6a3ce929d0e0e4736" def test_filter_by_actions(self): - """Test filtering by actions.""" - query = EventQueryRequest(actions=["deny", "warn"]) + """Test filtering by canonical actions.""" + # Given: a query with canonical action filter values + query = EventQueryRequest(actions=["deny", "observe"]) + + # Then: the actions are accepted as-is assert query.actions == ["deny", "observe"] + def test_filter_by_actions_rejects_legacy(self): + """Test that legacy action values are rejected in query filters.""" + # Given: a query filter that includes the legacy "warn" value + # When / Then: validation rejects it at the API boundary + with pytest.raises(ValidationError, match="Invalid action"): + EventQueryRequest(actions=["deny", "warn"]) + def test_limit_bounds(self): """Test limit bounds.""" with pytest.raises(ValidationError):