diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 29652d0..463acd0 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -2,9 +2,9 @@ name: Build on: push: - branches: [ "main" ] + branches: ["main"] pull_request: - branches: [ "main" ] + branches: ["main"] env: CARGO_TERM_COLOR: always @@ -32,6 +32,15 @@ jobs: - name: Setup Rust toolchain run: rustup toolchain install stable --profile minimal + - name: Install etcdctl + env: + ETCD_VER: v3.6.11 + run: | + curl -fsSL "https://github.com/etcd-io/etcd/releases/download/${ETCD_VER}/etcd-${ETCD_VER}-linux-amd64.tar.gz" -o /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz + tar xzf /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz -C /tmp + sudo install /tmp/etcd-${ETCD_VER}-linux-amd64/etcdctl /usr/local/bin/etcdctl + etcdctl version + - name: Setup environment run: sudo docker compose -f ci/docker-compose.yaml up -d diff --git a/Cargo.lock b/Cargo.lock index 7cf5ab5..e17edef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,6 +72,7 @@ dependencies = [ "pin-project", "pretty_assertions", "rand 0.10.1", + "regex", "reqwest", "rstest", "rust-embed", @@ -115,10 +116,12 @@ dependencies = [ "aws-smithy-runtime-api", "http 1.4.0", "percent-encoding", + "regex", "reqwest", "serde", "serde_json", "thiserror", + "tokio", "utoipa", ] diff --git a/Cargo.toml b/Cargo.toml index a41e461..48ac06f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ reqwest = { version = "0.13", default-features = false, features = [ "stream", "native-tls", ] } +regex = "1" serde = "1.0.228" serde_json = "1.0" thiserror = "2" @@ -63,6 +64,7 @@ tokio.workspace = true reqwest.workspace = true http.workspace = true axum.workspace = true +regex.workspace = true serde.workspace = true serde_json.workspace = true async-trait.workspace = true diff --git a/crates/aisix-guardrail/Cargo.toml b/crates/aisix-guardrail/Cargo.toml index 0cd8a7e..42e480f 100644 --- a/crates/aisix-guardrail/Cargo.toml +++ b/crates/aisix-guardrail/Cargo.toml @@ -11,8 +11,12 @@ aws-sigv4.workspace = true aws-smithy-runtime-api.workspace = true http.workspace = true percent-encoding.workspace = true +regex.workspace = true reqwest.workspace = true serde.workspace = true serde_json.workspace = true thiserror.workspace = true utoipa.workspace = true + +[dev-dependencies] +tokio.workspace = true diff --git a/crates/aisix-guardrail/src/guardrails/mod.rs b/crates/aisix-guardrail/src/guardrails/mod.rs index 979724b..baf655c 100644 --- a/crates/aisix-guardrail/src/guardrails/mod.rs +++ b/crates/aisix-guardrail/src/guardrails/mod.rs @@ -1,13 +1,18 @@ pub mod bedrock; +pub mod regex; -pub use bedrock::{BedrockGuardrailMeta, BedrockGuardrailRuntime}; +pub use self::{ + bedrock::{BedrockGuardrailMeta, BedrockGuardrailRuntime}, + regex::{RegexGuardrailMeta, RegexGuardrailRuntime}, +}; pub mod identifiers { - use super::bedrock; + use super::{bedrock, regex}; pub const BEDROCK: &str = bedrock::IDENTIFIER; + pub const REGEX: &str = regex::IDENTIFIER; } pub mod configs { - pub use super::bedrock::BedrockGuardrailConfig; + pub use super::{bedrock::BedrockGuardrailConfig, regex::RegexGuardrailConfig}; } diff --git a/crates/aisix-guardrail/src/guardrails/regex.rs b/crates/aisix-guardrail/src/guardrails/regex.rs new file mode 100644 index 0000000..fd3142e --- /dev/null +++ b/crates/aisix-guardrail/src/guardrails/regex.rs @@ -0,0 +1,240 @@ +use std::convert::Infallible; + +use async_trait::async_trait; +use regex::Regex; +use serde::{Deserialize, Deserializer, Serialize, de}; +use utoipa::ToSchema; + +use crate::traits::{ + GuardrailCheckPayload, GuardrailContentPart, GuardrailMessage, GuardrailMessageContent, + GuardrailMeta, GuardrailOutcome, GuardrailRuntime, +}; + +pub const IDENTIFIER: &str = "regex"; + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct RegexGuardrailConfig { + pub pattern: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub block_reason: Option, + + #[serde(skip)] + #[schema(ignore)] + compiled_pattern: Regex, +} + +impl RegexGuardrailConfig { + pub fn new( + pattern: impl Into, + block_reason: Option, + ) -> Result { + let pattern = pattern.into(); + let compiled_pattern = Regex::new(&pattern)?; + + Ok(Self { + pattern, + block_reason, + compiled_pattern, + }) + } + + pub fn compiled_pattern(&self) -> &Regex { + &self.compiled_pattern + } +} + +impl<'de> Deserialize<'de> for RegexGuardrailConfig { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct RawRegexGuardrailConfig { + pattern: String, + #[serde(default)] + block_reason: Option, + } + + let raw = RawRegexGuardrailConfig::deserialize(deserializer)?; + + Self::new(raw.pattern, raw.block_reason) + .map_err(|error| de::Error::custom(format!("invalid regex guardrail pattern: {error}"))) + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct RegexGuardrailMeta; + +impl GuardrailMeta for RegexGuardrailMeta { + fn name(&self) -> &'static str { + IDENTIFIER + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct RegexGuardrailRuntime; + +impl RegexGuardrailRuntime { + pub fn new() -> Self { + Self + } +} + +impl GuardrailMeta for RegexGuardrailRuntime { + fn name(&self) -> &'static str { + IDENTIFIER + } +} + +#[async_trait] +impl GuardrailRuntime for RegexGuardrailRuntime { + type Error = Infallible; + + async fn check( + &self, + payload: &GuardrailCheckPayload, + config: &RegexGuardrailConfig, + ) -> Result { + if payload_matches(config.compiled_pattern(), payload) { + return Ok(GuardrailOutcome::Block { + reason: config + .block_reason + .clone() + .unwrap_or_else(|| "regex guardrail blocked".into()), + }); + } + + Ok(GuardrailOutcome::Allow) + } +} + +fn payload_matches(pattern: &Regex, payload: &GuardrailCheckPayload) -> bool { + let messages = match payload { + GuardrailCheckPayload::Input(payload) => &payload.messages, + GuardrailCheckPayload::Output(payload) => &payload.messages, + }; + + messages + .iter() + .any(|message| message_matches(pattern, message)) +} + +fn message_matches(pattern: &Regex, message: &GuardrailMessage) -> bool { + match &message.content { + Some(GuardrailMessageContent::Text(text)) => pattern.is_match(text), + Some(GuardrailMessageContent::Parts(parts)) => parts.iter().any(|part| match part { + GuardrailContentPart::Text { text } => pattern.is_match(text), + GuardrailContentPart::ImageUrl { .. } => false, + }), + None => false, + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::{RegexGuardrailConfig, RegexGuardrailRuntime}; + use crate::traits::{ + GuardrailCheckPayload, GuardrailContentPart, GuardrailImageUrl, GuardrailMessage, + GuardrailMessageContent, GuardrailOutcome, GuardrailRole, GuardrailRuntime, + InputGuardrailPayload, + }; + + fn config(pattern: &str) -> RegexGuardrailConfig { + RegexGuardrailConfig::new(pattern, Some("matched blocked content".into())).unwrap() + } + + fn runtime() -> RegexGuardrailRuntime { + RegexGuardrailRuntime::new() + } + + fn input_payload(content: GuardrailMessageContent) -> GuardrailCheckPayload { + GuardrailCheckPayload::Input(InputGuardrailPayload { + messages: vec![GuardrailMessage { + role: GuardrailRole::User, + content: Some(content), + name: None, + tool_calls: None, + tool_call_id: None, + }], + }) + } + + #[tokio::test] + async fn blocks_when_plain_text_matches_pattern() { + let outcome = runtime() + .check( + &input_payload(GuardrailMessageContent::Text( + "my secret token is 12345".into(), + )), + &config(r"secret token"), + ) + .await + .unwrap(); + + assert_eq!( + outcome, + GuardrailOutcome::Block { + reason: "matched blocked content".into(), + } + ); + } + + #[tokio::test] + async fn allows_when_no_message_text_matches_pattern() { + let outcome = runtime() + .check( + &input_payload(GuardrailMessageContent::Text("hello world".into())), + &config(r"secret token"), + ) + .await + .unwrap(); + + assert_eq!(outcome, GuardrailOutcome::Allow); + } + + #[tokio::test] + async fn matches_text_parts_and_ignores_non_text_parts() { + let outcome = runtime() + .check( + &input_payload(GuardrailMessageContent::Parts(vec![ + GuardrailContentPart::ImageUrl { + image_url: GuardrailImageUrl { + url: "https://example.com/cat.png".into(), + detail: Some("high".into()), + }, + }, + GuardrailContentPart::Text { + text: "contains credit card 4111111111111111".into(), + }, + ])), + &config(r"\b\d{16}\b"), + ) + .await + .unwrap(); + + assert_eq!( + outcome, + GuardrailOutcome::Block { + reason: "matched blocked content".into(), + } + ); + } + + #[test] + fn deserialize_rejects_invalid_patterns() { + let error = serde_json::from_value::(json!({ + "pattern": "[", + "block_reason": "matched blocked content" + })) + .unwrap_err(); + + assert!( + error + .to_string() + .contains("invalid regex guardrail pattern") + ); + } +} diff --git a/src/config/entities/guardrails-schema.json b/src/config/entities/guardrails-schema.json index 5d8269e..4bcaa84 100644 --- a/src/config/entities/guardrails-schema.json +++ b/src/config/entities/guardrails-schema.json @@ -5,7 +5,7 @@ "name": { "type": "string", "minLength": 1 }, "type": { "type": "string", - "enum": ["bedrock"] + "enum": ["bedrock", "regex"] }, "config": { "type": "object" } }, @@ -22,6 +22,17 @@ "config": { "$ref": "#/$defs/bedrock" } } } + }, + { + "if": { + "properties": { "type": { "const": "regex" } }, + "required": ["type"] + }, + "then": { + "properties": { + "config": { "$ref": "#/$defs/regex" } + } + } } ], "$defs": { @@ -44,6 +55,15 @@ "endpoint": { "type": "string", "minLength": 1 } }, "additionalProperties": false + }, + "regex": { + "type": "object", + "required": ["pattern"], + "properties": { + "pattern": { "type": "string", "minLength": 1 }, + "block_reason": { "type": "string", "minLength": 1 } + }, + "additionalProperties": false } } } diff --git a/src/config/entities/guardrails.rs b/src/config/entities/guardrails.rs index ea05924..2b06929 100644 --- a/src/config/entities/guardrails.rs +++ b/src/config/entities/guardrails.rs @@ -8,7 +8,10 @@ use utoipa::ToSchema; use super::{ConfigProvider, EntityStore, ResourceEntry}; use crate::{ - guardrail::guardrails::{configs::BedrockGuardrailConfig, identifiers}, + guardrail::guardrails::{ + configs::{BedrockGuardrailConfig, RegexGuardrailConfig}, + identifiers, + }, utils::jsonschema::format_evaluation_error, }; @@ -25,12 +28,16 @@ pub static SCHEMA_VALIDATOR: LazyLock = LazyLock::new(|| pub enum GuardrailConfig { #[serde(rename = "bedrock")] Bedrock(BedrockGuardrailConfig), + + #[serde(rename = "regex")] + Regex(RegexGuardrailConfig), } impl GuardrailConfig { pub fn guardrail_type(&self) -> &'static str { match self { Self::Bedrock(_) => identifiers::BEDROCK, + Self::Regex(_) => identifiers::REGEX, } } } @@ -62,9 +69,18 @@ fn validate(key: &str, value: &Guardrail) -> Result<(), String> { )); } + validate_config(key, &value.guardrail)?; + Ok(()) } +fn validate_config(_key: &str, config: &GuardrailConfig) -> Result<(), String> { + match config { + GuardrailConfig::Bedrock(_) => Ok(()), + GuardrailConfig::Regex(_) => Ok(()), + } +} + #[derive(Clone)] pub struct GuardrailsStore { store: EntityStore, @@ -212,6 +228,14 @@ mod tests { "secret_access_key": "secret" } }), true, None)] + #[case::regex_ok(json!({ + "name": "regex-prod", + "type": "regex", + "config": { + "pattern": "secret", + "block_reason": "matched blocked content" + } + }), true, None)] #[case::missing_type(json!({ "name": "bedrock-prod", "config": { @@ -252,6 +276,13 @@ mod tests { "secret_access_key": "secret" } }), false, Some(r#"property "/config" validation failed: "region" is a required property"#.to_string()))] + #[case::regex_missing_pattern(json!({ + "name": "regex-prod", + "type": "regex", + "config": { + "block_reason": "matched blocked content" + } + }), false, Some(r#"property "/config" validation failed: "pattern" is a required property"#.to_string()))] #[case::invalid_root_additional_property(json!({ "name": "bedrock-prod", "type": "bedrock", @@ -281,6 +312,25 @@ mod tests { } } + #[test] + fn deserialize_regex_guardrail_rejects_invalid_patterns() { + let error = serde_json::from_value::(json!({ + "name": "regex-invalid", + "type": "regex", + "config": { + "pattern": "[", + "block_reason": "matched blocked content" + } + })) + .expect_err("invalid regex pattern should be rejected while loading config"); + + assert!( + error + .to_string() + .contains("invalid regex guardrail pattern") + ); + } + #[test] fn deserialize_guardrail_preserves_type_information() { let guardrail: Guardrail = serde_json::from_value(json!({ @@ -301,6 +351,23 @@ mod tests { assert_matches!(guardrail.guardrail, GuardrailConfig::Bedrock(_)); } + #[test] + fn deserialize_regex_guardrail_preserves_type_information() { + let guardrail: Guardrail = serde_json::from_value(json!({ + "name": "regex-prod", + "type": "regex", + "config": { + "pattern": "secret", + "block_reason": "matched blocked content" + } + })) + .unwrap(); + + assert_eq!(guardrail.name, "regex-prod"); + assert_eq!(guardrail.guardrail_type(), "regex"); + assert_matches!(guardrail.guardrail, GuardrailConfig::Regex(_)); + } + #[tokio::test] async fn guardrails_store_loads_full_snapshot_with_relative_ids() { let (tx, rx) = mpsc::channel(8); diff --git a/src/proxy/guardrails.rs b/src/proxy/guardrails.rs index ab8c488..3c02175 100644 --- a/src/proxy/guardrails.rs +++ b/src/proxy/guardrails.rs @@ -1,5 +1,5 @@ use aisix_guardrail::{ - guardrails::BedrockGuardrailRuntime, + guardrails::{BedrockGuardrailRuntime, RegexGuardrailRuntime}, traits::{ GuardrailCheckPayload, GuardrailContentPart, GuardrailImageUrl, GuardrailMessage, GuardrailMessageContent, GuardrailOutcome, GuardrailRole, GuardrailRuntime, GuardrailStage, @@ -278,6 +278,10 @@ fn configured_guardrail_runtime_from_configs( BedrockGuardrailRuntime::new(), config.clone(), ))), + GuardrailConfig::Regex(config) => Ok(Box::new(GuardrailRuntimeHandle::new( + RegexGuardrailRuntime::new(), + config.clone(), + ))), } } @@ -353,7 +357,7 @@ mod tests { use std::sync::Mutex; use aisix_guardrail::{ - guardrails::configs::BedrockGuardrailConfig, + guardrails::configs::{BedrockGuardrailConfig, RegexGuardrailConfig}, traits::{ GuardrailCheckPayload, GuardrailContentPart, GuardrailMessage, GuardrailMessageContent, GuardrailMeta, GuardrailOutcome, GuardrailRole, GuardrailRuntime, GuardrailStage, @@ -643,6 +647,17 @@ mod tests { assert!(runtime.supports_stage(GuardrailStage::Input)); } + #[test] + fn configured_guardrail_runtime_from_configs_builds_regex_runtime() { + let runtime = configured_guardrail_runtime_from_configs(&GuardrailConfig::Regex( + RegexGuardrailConfig::new("secret", Some("matched blocked content".into())).unwrap(), + )) + .unwrap(); + + assert_eq!(runtime.name(), "regex"); + assert!(runtime.supports_stage(GuardrailStage::Output)); + } + #[test] fn input_payload_to_chat_messages_round_trips() { let payload = InputGuardrailPayload { diff --git a/tests/package.json b/tests/package.json index ccd9543..b400d53 100644 --- a/tests/package.json +++ b/tests/package.json @@ -8,7 +8,7 @@ "test": "vitest run", "test:dev": "vitest" }, - "packageManager": "pnpm@11.1.1+sha512.d1fdf5f73c617b64fa1a56a81c3c8dfe0e966e33a6010aa256b517ae77be21d93e05affc0de1a83b0e4f29d569f68b446ae8f068cd7247c0bb3df0fb4d7bdf9a", + "packageManager": "pnpm@11.1.2+sha512.415a1cc25974731e75455c1468371be74c5aa5fb7621b50d4056d222451609f11412f23fd602e6169f1e060466641f798597e1be961a10688836a67b16569499", "devDependencies": { "@anthropic-ai/sdk": "^0.88.0", "@eslint/js": "^10.0.1", diff --git a/tests/proxy/guardrail/chat-completions.test.ts b/tests/proxy/guardrail/chat-completions.test.ts new file mode 100644 index 0000000..3e33f09 --- /dev/null +++ b/tests/proxy/guardrail/chat-completions.test.ts @@ -0,0 +1,109 @@ +import { proxyPost } from '../../utils/proxy.js'; +import { + type RegexGuardrailFixture, + setupOpenAiRegexGuardrailFixture, +} from './shared.js'; + +const ADMIN_KEY = 'test_admin_key_guardrail_chat_completions'; +const AUTHORIZED_KEY = 'sk-proxy-guardrail-chat-completions'; +const UPSTREAM_API_KEY = 'upstream-key-guardrail-chat-completions'; +const UPSTREAM_MODEL = 'test-model'; + +describe('proxy guardrail /v1/chat/completions', () => { + let fixture: RegexGuardrailFixture | undefined; + + beforeEach(async () => { + fixture = await setupOpenAiRegexGuardrailFixture({ + adminKey: ADMIN_KEY, + authorizedKey: AUTHORIZED_KEY, + upstreamApiKey: UPSTREAM_API_KEY, + upstreamModel: UPSTREAM_MODEL, + modelPrefix: 'mock-chat-guardrail', + }); + }, 30_000); + + afterEach(async () => { + await fixture?.close(); + }); + + test('input regex guardrail blocks request before upstream call', async () => { + const resp = await proxyPost( + '/v1/chat/completions', + { + model: fixture?.inputGuardedModelName, + messages: [{ role: 'user', content: 'my secret token is 12345' }], + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(400); + expect(resp.data.error.code).toBe('gateway_error'); + expect(resp.data.error.type).toBe('invalid_request_error'); + expect(resp.data.error.message).toContain('guardrail regex blocked input'); + expect(resp.data.error.message).toContain( + 'blocked by regex input guardrail', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(0); + }); + + test('input regex guardrail allows safe request through to upstream', async () => { + const resp = await proxyPost( + '/v1/chat/completions', + { + model: fixture?.inputGuardedModelName, + messages: [ + { role: 'user', content: 'safe request through regex guardrail' }, + ], + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(200); + expect(resp.data.choices[0].message.content).toBe( + 'hello from mock upstream', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + expect( + ( + recorded[0]?.bodyJson as { + messages: Array<{ content: string }>; + } + ).messages[0]?.content, + ).toBe('safe request through regex guardrail'); + }); + + test('output regex guardrail blocks matched upstream response', async () => { + const resp = await proxyPost( + '/v1/chat/completions', + { + model: fixture?.outputGuardedModelName, + messages: [ + { role: 'user', content: 'safe prompt for output guardrail' }, + ], + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(400); + expect(resp.data.error.code).toBe('gateway_error'); + expect(resp.data.error.type).toBe('invalid_request_error'); + expect(resp.data.error.message).toContain('guardrail regex blocked output'); + expect(resp.data.error.message).toContain( + 'blocked by regex output guardrail', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + expect( + ( + recorded[0]?.bodyJson as { + messages: Array<{ content: string }>; + } + ).messages[0]?.content, + ).toBe('safe prompt for output guardrail'); + }); +}); diff --git a/tests/proxy/guardrail/messages.test.ts b/tests/proxy/guardrail/messages.test.ts new file mode 100644 index 0000000..68279a5 --- /dev/null +++ b/tests/proxy/guardrail/messages.test.ts @@ -0,0 +1,109 @@ +import { buildOpenAiProviderModel } from '../../utils/mock-upstream.js'; +import { proxyPost } from '../../utils/proxy.js'; +import { + type RegexGuardrailFixture, + setupOpenAiRegexGuardrailFixture, +} from './shared.js'; + +const ADMIN_KEY = 'test_admin_key_guardrail_messages'; +const AUTHORIZED_KEY = 'sk-proxy-guardrail-messages'; +const UPSTREAM_API_KEY = 'upstream-key-guardrail-messages'; +const UPSTREAM_MODEL = 'test-model'; + +describe('proxy guardrail /v1/messages', () => { + let fixture: RegexGuardrailFixture | undefined; + + beforeEach(async () => { + fixture = await setupOpenAiRegexGuardrailFixture({ + adminKey: ADMIN_KEY, + authorizedKey: AUTHORIZED_KEY, + upstreamApiKey: UPSTREAM_API_KEY, + upstreamModel: UPSTREAM_MODEL, + modelPrefix: 'mock-messages-guardrail', + buildModel: buildOpenAiProviderModel, + }); + }, 30_000); + + afterEach(async () => { + await fixture?.close(); + }); + + test('input regex guardrail blocks messages request before upstream call', async () => { + const resp = await proxyPost( + '/v1/messages', + { + model: fixture?.inputGuardedModelName, + max_tokens: 256, + messages: [{ role: 'user', content: 'my secret token is 12345' }], + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(400); + expect(resp.data.type).toBe('error'); + expect(resp.data.error.type).toBe('invalid_request_error'); + expect(resp.data.error.message).toBe('Invalid request'); + expect(typeof resp.data.request_id).toBe('string'); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(0); + }); + + test('input regex guardrail allows safe messages request through to upstream', async () => { + const resp = await proxyPost( + '/v1/messages', + { + model: fixture?.inputGuardedModelName, + max_tokens: 256, + messages: [ + { role: 'user', content: 'safe request through regex guardrail' }, + ], + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(200); + expect(resp.data.type).toBe('message'); + expect(resp.data.content[0].text).toBe('hello from mock upstream'); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + expect( + ( + recorded[0]?.bodyJson as { + messages: Array<{ content: string }>; + } + ).messages[0]?.content, + ).toBe('safe request through regex guardrail'); + }); + + test('output regex guardrail blocks matched messages response', async () => { + const resp = await proxyPost( + '/v1/messages', + { + model: fixture?.outputGuardedModelName, + max_tokens: 256, + messages: [ + { role: 'user', content: 'safe prompt for output guardrail' }, + ], + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(400); + expect(resp.data.type).toBe('error'); + expect(resp.data.error.type).toBe('invalid_request_error'); + expect(resp.data.error.message).toBe('Invalid request'); + expect(typeof resp.data.request_id).toBe('string'); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + expect( + ( + recorded[0]?.bodyJson as { + messages: Array<{ content: string }>; + } + ).messages[0]?.content, + ).toBe('safe prompt for output guardrail'); + }); +}); diff --git a/tests/proxy/guardrail/responses.test.ts b/tests/proxy/guardrail/responses.test.ts new file mode 100644 index 0000000..6a60f77 --- /dev/null +++ b/tests/proxy/guardrail/responses.test.ts @@ -0,0 +1,106 @@ +import { proxyPost } from '../../utils/proxy.js'; +import { + type RegexGuardrailFixture, + setupOpenAiRegexGuardrailFixture, +} from './shared.js'; + +const ADMIN_KEY = 'test_admin_key_guardrail_responses'; +const AUTHORIZED_KEY = 'sk-proxy-guardrail-responses'; +const UPSTREAM_API_KEY = 'upstream-key-guardrail-responses'; +const UPSTREAM_MODEL = 'test-model'; + +describe('proxy guardrail /v1/responses', () => { + let fixture: RegexGuardrailFixture | undefined; + + beforeEach(async () => { + fixture = await setupOpenAiRegexGuardrailFixture({ + adminKey: ADMIN_KEY, + authorizedKey: AUTHORIZED_KEY, + upstreamApiKey: UPSTREAM_API_KEY, + upstreamModel: UPSTREAM_MODEL, + modelPrefix: 'mock-responses-guardrail', + }); + }, 30_000); + + afterEach(async () => { + await fixture?.close(); + }); + + test('input regex guardrail blocks responses request before upstream call', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: fixture?.inputGuardedModelName, + input: 'my secret token is 12345', + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(400); + expect(resp.data.error.code).toBe('gateway_error'); + expect(resp.data.error.type).toBe('invalid_request_error'); + expect(resp.data.error.message).toContain('guardrail regex blocked input'); + expect(resp.data.error.message).toContain( + 'blocked by regex input guardrail', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(0); + }); + + test('input regex guardrail allows safe responses request through to upstream', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: fixture?.inputGuardedModelName, + input: 'safe request through regex guardrail', + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(200); + expect(resp.data.object).toBe('response'); + expect(resp.data.output[0].content[0].text).toBe( + 'hello from mock upstream', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + expect( + ( + recorded[0]?.bodyJson as { + messages: Array<{ content: string }>; + } + ).messages[0]?.content, + ).toBe('safe request through regex guardrail'); + }); + + test('output regex guardrail blocks matched responses output', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: fixture?.outputGuardedModelName, + input: 'safe prompt for output guardrail', + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(400); + expect(resp.data.error.code).toBe('gateway_error'); + expect(resp.data.error.type).toBe('invalid_request_error'); + expect(resp.data.error.message).toContain('guardrail regex blocked output'); + expect(resp.data.error.message).toContain( + 'blocked by regex output guardrail', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + expect( + ( + recorded[0]?.bodyJson as { + messages: Array<{ content: string }>; + } + ).messages[0]?.content, + ).toBe('safe prompt for output guardrail'); + }); +}); diff --git a/tests/proxy/guardrail/shared.ts b/tests/proxy/guardrail/shared.ts new file mode 100644 index 0000000..ba355d9 --- /dev/null +++ b/tests/proxy/guardrail/shared.ts @@ -0,0 +1,167 @@ +import { randomUUID } from 'node:crypto'; + +import { + MODELS_URL, + PROVIDERS_URL, + adminPost, + adminPut, + bearerAuthHeader, + startIsolatedAdminApp, +} from '../../utils/admin.js'; +import { etcdPutJson } from '../../utils/etcd.js'; +import { + type OpenAiMockUpstream, + buildOpenAiProviderConfig, + startOpenAiMockUpstream, +} from '../../utils/mock-upstream.js'; +import { App } from '../../utils/setup.js'; + +const waitConfigPropagation = async () => { + await new Promise((resolve) => setTimeout(resolve, 1000)); +}; + +const ensureStatus = ( + response: { status: number; data?: unknown }, + expectedStatus: number, + context: string, +) => { + if (response.status !== expectedStatus) { + throw new Error( + `${context} failed: expected ${expectedStatus}, got ${response.status}: ${JSON.stringify(response.data)}`, + ); + } +}; + +export interface RegexGuardrailFixture { + server: App; + upstream: OpenAiMockUpstream; + inputGuardedModelName: string; + outputGuardedModelName: string; + close: () => Promise; +} + +interface SetupOpenAiRegexGuardrailFixtureOptions { + adminKey: string; + authorizedKey: string; + upstreamApiKey: string; + upstreamModel: string; + modelPrefix: string; + buildModel?: (model: string) => string; +} + +export const setupOpenAiRegexGuardrailFixture = async ({ + adminKey, + authorizedKey, + upstreamApiKey, + upstreamModel, + modelPrefix, + buildModel = (model) => model, +}: SetupOpenAiRegexGuardrailFixtureOptions): Promise => { + const etcdPrefix = `/ai-admin-${randomUUID()}`; + let server: App | undefined; + let upstream: OpenAiMockUpstream | undefined; + + try { + server = await startIsolatedAdminApp(adminKey, etcdPrefix); + upstream = await startOpenAiMockUpstream(); + const auth = bearerAuthHeader(adminKey); + + const inputGuardedModelName = `${modelPrefix}-input-${randomUUID()}`; + const outputGuardedModelName = `${modelPrefix}-output-${randomUUID()}`; + const providerId = `${modelPrefix}-provider-${randomUUID()}`; + const inputGuardrailId = `${modelPrefix}-regex-input-${randomUUID()}`; + const outputGuardrailId = `${modelPrefix}-regex-output-${randomUUID()}`; + + ensureStatus( + await adminPut( + `${PROVIDERS_URL}/${providerId}`, + { + name: providerId, + type: 'openai', + config: buildOpenAiProviderConfig(upstream.apiBase, upstreamApiKey), + }, + auth, + ), + 201, + 'create provider', + ); + + await etcdPutJson(etcdPrefix, `/guardrails/${inputGuardrailId}`, { + name: `${modelPrefix}-regex-input`, + type: 'regex', + config: { + pattern: 'secret token', + block_reason: 'blocked by regex input guardrail', + }, + }); + + await etcdPutJson(etcdPrefix, `/guardrails/${outputGuardrailId}`, { + name: `${modelPrefix}-regex-output`, + type: 'regex', + config: { + pattern: 'hello from mock upstream', + block_reason: 'blocked by regex output guardrail', + }, + }); + + ensureStatus( + await adminPost( + MODELS_URL, + { + name: inputGuardedModelName, + model: buildModel(upstreamModel), + provider_id: providerId, + guardrail_ids: [inputGuardrailId], + }, + auth, + ), + 201, + 'create input-guarded model', + ); + + ensureStatus( + await adminPost( + MODELS_URL, + { + name: outputGuardedModelName, + model: buildModel(upstreamModel), + provider_id: providerId, + guardrail_ids: [outputGuardrailId], + }, + auth, + ), + 201, + 'create output-guarded model', + ); + + ensureStatus( + await adminPost( + '/apikeys', + { + key: authorizedKey, + allowed_models: [inputGuardedModelName, outputGuardedModelName], + }, + auth, + ), + 201, + 'create apikey', + ); + + await waitConfigPropagation(); + + return { + server, + upstream, + inputGuardedModelName, + outputGuardedModelName, + close: async () => { + await upstream?.close(); + await server?.exit(); + }, + }; + } catch (error) { + await upstream?.close(); + await server?.exit(); + throw error; + } +}; diff --git a/tests/utils/admin.ts b/tests/utils/admin.ts index 5c71d8e..9509e63 100644 --- a/tests/utils/admin.ts +++ b/tests/utils/admin.ts @@ -25,13 +25,16 @@ export const extractIdFromStorageKey = (storageKey: string) => { return id; }; -export const startIsolatedAdminApp = async (adminKey: string) => { +export const startIsolatedAdminApp = async ( + adminKey: string, + etcdPrefix = `/ai-admin-${randomUUID()}`, +) => { return (await ( await App.spawn( defaultConfig({ deployment: { etcd: { - prefix: `/ai-admin-${randomUUID()}`, + prefix: etcdPrefix, }, admin: { admin_key: [{ key: adminKey }] }, }, diff --git a/tests/utils/etcd.ts b/tests/utils/etcd.ts new file mode 100644 index 0000000..4d54d9d --- /dev/null +++ b/tests/utils/etcd.ts @@ -0,0 +1,27 @@ +import { execFile } from 'node:child_process'; +import { promisify } from 'node:util'; + +const execFileAsync = promisify(execFile); + +const ETCDCTL = process.env.ETCDCTL_PATH ?? 'etcdctl'; +const ETCD_ENDPOINTS = process.env.ETCD_ENDPOINTS ?? 'http://127.0.0.1:2379'; + +export const etcdPutJson = async ( + prefix: string, + path: string, + value: unknown, +) => { + await execFileAsync( + ETCDCTL, + [ + '--endpoints', + ETCD_ENDPOINTS, + '--dial-timeout=5s', + '--command-timeout=5s', + 'put', + `${prefix}${path}`, + JSON.stringify(value), + ], + { timeout: 10_000 }, + ); +};