Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/config/entities/models-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
"name": { "type": "string", "minLength": 1 },
"provider_id": { "type": "string", "minLength": 1 },
"model": { "type": "string", "minLength": 1 },
"guardrail_ids": {
"type": "array",
"items": { "type": "string", "minLength": 1 }
},
"timeout": {
"type": "integer",
"minimum": 0
Expand Down
27 changes: 4 additions & 23 deletions src/config/entities/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ pub struct Model {
pub provider_id: String,
pub model: String,

// Temporary binding surface for guardrail runtime wiring until policy evaluation attaches
// guardrails dynamically. Keeping this on Model lets tests and the current runtime path stay
// simple without committing to the long-term control-plane shape.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub guardrail_ids: Vec<String>,

#[serde(skip_serializing_if = "Option::is_none")]
pub timeout: Option<u64>,

Expand Down Expand Up @@ -115,12 +109,6 @@ mod tests {
"provider_id": "openai-primary",
"model": "gpt-5"
}), true, None)]
#[case::ok_with_guardrails(json!({
"name": "test",
"provider_id": "bedrock-primary",
"model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"guardrail_ids": ["gr-input", "gr-output"]
}), true, None)]
#[case::ok_with_rate_limit(json!({
"name": "test",
"provider_id": "bedrock-primary",
Expand Down Expand Up @@ -154,18 +142,12 @@ mod tests {
"provider_id": "openai-primary",
"model": 123
}), false, Some(r#"property "/model" validation failed: 123 is not of type "string""#.to_string()))]
#[case::invalid_guardrail_ids_type(json!({
"name": "test",
"provider_id": "openai-primary",
"model": "gpt-5",
"guardrail_ids": "gr-input"
}), false, Some(r#"property "/guardrail_ids" validation failed: "gr-input" is not of type "array""#.to_string()))]
#[case::invalid_guardrail_ids_element_type(json!({
#[case::legacy_guardrail_ids_rejected_by_schema(json!({
"name": "test",
"provider_id": "openai-primary",
"model": "gpt-5",
"guardrail_ids": [1]
}), false, Some(r#"property "/guardrail_ids/0" validation failed: 1 is not of type "string""#.to_string()))]
"guardrail_ids": ["gr-input"]
}), false, Some(r#"property "/" validation failed: Additional properties are not allowed ('guardrail_ids' was unexpected)"#.to_string()))]
#[case::invalid_root_additional_property(json!({
"name": "test",
"provider_id": "openai-primary",
Expand All @@ -190,7 +172,7 @@ mod tests {
}

#[test]
fn deserialize_model_preserves_provider_reference_and_model_name() {
fn deserialize_model_ignores_legacy_guardrail_ids() {
let model: super::Model = serde_json::from_value(json!({
"name": "test",
"provider_id": "bedrock-primary",
Expand All @@ -206,7 +188,6 @@ mod tests {
model.model,
"arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.anthropic.claude-3-5-sonnet-20240620-v1:0"
);
assert_eq!(model.guardrail_ids, vec!["gr-input"]);
assert_eq!(model.timeout, Some(30000));
}
}
84 changes: 38 additions & 46 deletions src/proxy/guardrails.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use thiserror::Error;
pub(crate) mod streaming;

use crate::{
config::entities::{Model, ResourceEntry, ResourceRegistry, guardrails::GuardrailConfig},
config::entities::guardrails::GuardrailConfig,
gateway::{
error::GatewayError,
types::openai::{
Expand Down Expand Up @@ -42,7 +42,7 @@ where
}

#[async_trait]
pub(crate) trait ConfiguredGuardrailRuntime: Send + Sync {
pub(crate) trait ResolvedGuardrail: Send + Sync {
fn name(&self) -> &'static str;

fn supports_stage(&self, stage: GuardrailStage) -> bool;
Expand All @@ -53,19 +53,24 @@ pub(crate) trait ConfiguredGuardrailRuntime: Send + Sync {
) -> Result<Option<GuardrailOutcome>, GatewayError>;
}

struct GuardrailRuntimeHandle<R, C> {
struct RuntimeResolvedGuardrail<R, C> {
runtime: R,
config: C,
stage: GuardrailStage,
}

impl<R, C> GuardrailRuntimeHandle<R, C> {
fn new(runtime: R, config: C) -> Self {
Self { runtime, config }
impl<R, C> RuntimeResolvedGuardrail<R, C> {
fn new(runtime: R, config: C, stage: GuardrailStage) -> Self {
Self {
runtime,
config,
stage,
}
}
}

#[async_trait]
impl<R, C> ConfiguredGuardrailRuntime for GuardrailRuntimeHandle<R, C>
impl<R, C> ResolvedGuardrail for RuntimeResolvedGuardrail<R, C>
where
R: GuardrailRuntime<C> + Send + Sync,
C: Send + Sync,
Expand All @@ -76,7 +81,7 @@ where
}

fn supports_stage(&self, stage: GuardrailStage) -> bool {
self.runtime.supports_stage(stage)
self.stage == stage && self.runtime.supports_stage(stage)
}

async fn check(
Expand Down Expand Up @@ -213,29 +218,6 @@ pub(crate) fn output_payload_from_check_payload(
}
}

pub(crate) fn resolve_model_guardrails(
model: &ResourceEntry<Model>,
resources: &ResourceRegistry,
) -> Result<Vec<Box<dyn ConfiguredGuardrailRuntime>>, GatewayError> {
// This direct Model -> guardrail lookup is intentionally temporary. The long-term attachment
// point should come from policy evaluation so request-time guardrail selection is not encoded
// in the model resource itself.
model
.guardrail_ids
.iter()
.map(|guardrail_id| {
let guardrail = resources
.guardrails
.get_by_id(guardrail_id)
.ok_or_else(|| {
GatewayError::Internal(format!("guardrail {} not found", guardrail_id))
})?;

configured_guardrail_runtime_from_configs(&guardrail.guardrail)
})
.collect()
}

#[cfg(test)]
pub(crate) async fn run_guardrail_check<R, C>(
runtime: &R,
Expand Down Expand Up @@ -272,17 +254,20 @@ where
run_guardrail_check(runtime, config, payload).await
}

fn configured_guardrail_runtime_from_configs(
pub(crate) fn build_resolved_guardrail_for_stage(
guardrail: &GuardrailConfig,
) -> Result<Box<dyn ConfiguredGuardrailRuntime>, GatewayError> {
stage: GuardrailStage,
) -> Result<Box<dyn ResolvedGuardrail>, GatewayError> {
match guardrail {
GuardrailConfig::Bedrock(config) => Ok(Box::new(GuardrailRuntimeHandle::new(
GuardrailConfig::Bedrock(config) => Ok(Box::new(RuntimeResolvedGuardrail::new(
BedrockGuardrailRuntime::new(),
config.clone(),
stage,
))),
GuardrailConfig::Regex(config) => Ok(Box::new(GuardrailRuntimeHandle::new(
GuardrailConfig::Regex(config) => Ok(Box::new(RuntimeResolvedGuardrail::new(
RegexGuardrailRuntime::new(),
config.clone(),
stage,
))),
}
}
Comment thread
bzp2010 marked this conversation as resolved.
Expand Down Expand Up @@ -370,8 +355,8 @@ mod tests {
use thiserror::Error;

use super::{
GuardrailBridgeError, chat_message_to_guardrail_message,
configured_guardrail_runtime_from_configs, guardrail_message_to_chat_message,
GuardrailBridgeError, build_resolved_guardrail_for_stage,
chat_message_to_guardrail_message, guardrail_message_to_chat_message,
input_guardrail_payload_from_chat_messages, input_payload_from_check_payload,
input_payload_to_chat_messages, output_guardrail_payload_from_chat_messages,
output_payload_from_check_payload, output_payload_to_chat_messages,
Expand Down Expand Up @@ -631,33 +616,40 @@ mod tests {
}

#[test]
fn configured_guardrail_runtime_from_configs_builds_bedrock_runtime() {
let runtime = configured_guardrail_runtime_from_configs(&GuardrailConfig::Bedrock(
BedrockGuardrailConfig {
fn build_resolved_guardrail_for_stage_builds_bedrock_runtime() {
let runtime = build_resolved_guardrail_for_stage(
&GuardrailConfig::Bedrock(BedrockGuardrailConfig {
identifier: "guardrail-123".into(),
version: "1".into(),
region: "us-east-1".into(),
access_key_id: "AKIA123".into(),
secret_access_key: "secret".into(),
session_token: None,
endpoint: None,
},
))
}),
GuardrailStage::Input,
)
.unwrap();

assert_eq!(runtime.name(), "bedrock");
assert!(runtime.supports_stage(GuardrailStage::Input));
assert!(!runtime.supports_stage(GuardrailStage::Output));
}

#[test]
fn configured_guardrail_runtime_from_configs_builds_regex_runtime() {
let runtime = configured_guardrail_runtime_from_configs(&GuardrailConfig::Regex(
RegexGuardrailConfig::new("secret", Some("matched blocked content".into())).unwrap(),
))
fn build_resolved_guardrail_for_stage_builds_regex_runtime() {
let runtime = build_resolved_guardrail_for_stage(
&GuardrailConfig::Regex(
RegexGuardrailConfig::new("secret", Some("matched blocked content".into()))
.unwrap(),
),
GuardrailStage::Output,
)
.unwrap();

assert_eq!(runtime.name(), "regex");
assert!(runtime.supports_stage(GuardrailStage::Output));
assert!(!runtime.supports_stage(GuardrailStage::Input));
}

#[test]
Expand Down
Loading