diff --git a/Cargo.lock b/Cargo.lock index d960a85..7cf5ab5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,7 +110,15 @@ name = "aisix-guardrail" version = "0.1.0" dependencies = [ "async-trait", + "aws-credential-types", + "aws-sigv4", + "aws-smithy-runtime-api", + "http 1.4.0", + "percent-encoding", + "reqwest", "serde", + "serde_json", + "thiserror", "utoipa", ] diff --git a/crates/aisix-guardrail/Cargo.toml b/crates/aisix-guardrail/Cargo.toml index 4a071ca..0cd8a7e 100644 --- a/crates/aisix-guardrail/Cargo.toml +++ b/crates/aisix-guardrail/Cargo.toml @@ -6,5 +6,13 @@ rust-version.workspace = true [dependencies] async-trait.workspace = true +aws-credential-types.workspace = true +aws-sigv4.workspace = true +aws-smithy-runtime-api.workspace = true +http.workspace = true +percent-encoding.workspace = true +reqwest.workspace = true serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true utoipa.workspace = true diff --git a/crates/aisix-guardrail/src/guardrails/bedrock.rs b/crates/aisix-guardrail/src/guardrails/bedrock.rs new file mode 100644 index 0000000..4319bd7 --- /dev/null +++ b/crates/aisix-guardrail/src/guardrails/bedrock.rs @@ -0,0 +1,635 @@ +use std::{fmt, time::SystemTime}; + +use async_trait::async_trait; +use aws_credential_types::Credentials; +use aws_sigv4::{ + http_request::{SignableBody, SignableRequest, SigningSettings, sign}, + sign::v4, +}; +use aws_smithy_runtime_api::client::identity::Identity; +use http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header::CONTENT_TYPE}; +use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode}; +use reqwest::{Client, Url}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use utoipa::ToSchema; + +use crate::traits::{ + GuardrailCheckPayload, GuardrailMessage, GuardrailMessageContent, GuardrailMeta, + GuardrailOutcome, GuardrailRuntime, GuardrailStage, +}; + +pub const IDENTIFIER: &str = "bedrock"; + +const DEFAULT_OUTPUT_SCOPE: &str = "INTERVENTIONS"; 
+const DEFAULT_RUNTIME_HOST_PREFIX: &str = "https://bedrock-runtime."; +const DEFAULT_RUNTIME_HOST_SUFFIX: &str = ".amazonaws.com"; +type EncodedApplyGuardrailRequest = (Vec, Vec); + +const PATH_SEGMENT_ENCODE_SET: &AsciiSet = &CONTROLS + .add(b' ') + .add(b'"') + .add(b'#') + .add(b'%') + .add(b'<') + .add(b'>') + .add(b'?') + .add(b'`') + .add(b'{') + .add(b'}') + .add(b'/'); + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct BedrockGuardrailConfig { + pub identifier: String, + pub version: String, +} + +#[derive(Clone, Serialize, Deserialize, ToSchema)] +pub struct BedrockRuntimeConfig { + pub region: String, + pub access_key_id: String, + pub secret_access_key: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub session_token: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub endpoint: Option, +} + +impl fmt::Debug for BedrockRuntimeConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + const REDACTED: &str = "[REDACTED]"; + + f.debug_struct("BedrockRuntimeConfig") + .field("region", &self.region) + .field("access_key_id", &REDACTED) + .field("secret_access_key", &REDACTED) + .field( + "session_token", + &self.session_token.as_ref().map(|_| REDACTED), + ) + .field("endpoint", &self.endpoint) + .finish() + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct BedrockGuardrailMeta; + +impl GuardrailMeta for BedrockGuardrailMeta { + fn name(&self) -> &'static str { + IDENTIFIER + } +} + +#[derive(Debug, Clone)] +pub struct BedrockGuardrailRuntime { + client: Client, + runtime: BedrockRuntimeConfig, +} + +impl BedrockGuardrailRuntime { + pub fn new(runtime: BedrockRuntimeConfig) -> Self { + Self { + client: Client::new(), + runtime, + } + } + + pub fn with_client(client: Client, runtime: BedrockRuntimeConfig) -> Self { + Self { client, runtime } + } + + fn build_apply_url(&self, guardrail: &BedrockGuardrailConfig) -> Result { + let base_url = self + .runtime + .endpoint + .as_deref() + 
.map(str::to_owned) + .unwrap_or_else(|| { + format!( + "{DEFAULT_RUNTIME_HOST_PREFIX}{}{DEFAULT_RUNTIME_HOST_SUFFIX}", + self.runtime.region + ) + }); + let endpoint_path = format!( + "/guardrail/{}/version/{}/apply", + encode_path_segment(&guardrail.identifier), + encode_path_segment(&guardrail.version), + ); + + build_url_for_endpoint(&base_url, &endpoint_path) + } + + fn build_request_body( + &self, + payload: &GuardrailCheckPayload, + ) -> Result, BedrockError> { + let (content, message_indexes) = text_blocks_from_payload(payload); + if content.is_empty() { + return Ok(None); + } + + let body = ApplyGuardrailRequest { + content, + output_scope: Some(DEFAULT_OUTPUT_SCOPE), + source: match payload.stage() { + GuardrailStage::Input => "INPUT", + GuardrailStage::Output => "OUTPUT", + }, + }; + + Ok(Some((serde_json::to_vec(&body)?, message_indexes))) + } + + fn sign_request( + &self, + url: &Url, + body: &[u8], + time: SystemTime, + ) -> Result { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + + let header_pairs = headers + .iter() + .map(|(name, value)| { + let value = value.to_str().map_err(|error| { + BedrockError::Signing(format!( + "bedrock guardrail produced non-utf8 header {}: {}", + name, error + )) + })?; + Ok((name.as_str().to_owned(), value.to_owned())) + }) + .collect::, BedrockError>>()?; + + let identity: Identity = Credentials::new( + self.runtime.access_key_id.clone(), + self.runtime.secret_access_key.clone(), + self.runtime.session_token.clone(), + None, + "aisix-bedrock-guardrail-static", + ) + .into(); + let signing_params = v4::SigningParams::builder() + .identity(&identity) + .region(self.runtime.region.as_str()) + .name("bedrock") + .time(time) + .settings(SigningSettings::default()) + .build() + .map_err(|error| BedrockError::Signing(error.to_string()))? 
+ .into(); + let signable_request = SignableRequest::new( + Method::POST.as_str(), + url.as_str(), + header_pairs + .iter() + .map(|(name, value)| (name.as_str(), value.as_str())), + SignableBody::Bytes(body), + ) + .map_err(|error| BedrockError::Signing(error.to_string()))?; + + let mut signed_request = Request::builder().method(Method::POST).uri(url.as_str()); + for (name, value) in &headers { + signed_request = signed_request.header(name, value); + } + + let mut signed_request = signed_request + .body(()) + .map_err(|error| BedrockError::Signing(error.to_string()))?; + let (instructions, _signature) = sign(signable_request, &signing_params) + .map_err(|error| BedrockError::Signing(error.to_string()))? + .into_parts(); + instructions.apply_to_request_http1x(&mut signed_request); + + Ok(SignedRequest { + url: Url::parse(&signed_request.uri().to_string()) + .map_err(|error| BedrockError::Signing(error.to_string()))?, + headers: signed_request.headers().clone(), + }) + } +} + +impl GuardrailMeta for BedrockGuardrailRuntime { + fn name(&self) -> &'static str { + IDENTIFIER + } +} + +#[async_trait] +impl GuardrailRuntime for BedrockGuardrailRuntime { + type Error = BedrockError; + + async fn check( + &self, + payload: &GuardrailCheckPayload, + config: &BedrockGuardrailConfig, + ) -> Result { + let Some((body, message_indexes)) = self.build_request_body(payload)? 
else { + return Ok(GuardrailOutcome::Allow); + }; + + let url = self.build_apply_url(config)?; + let signed = self.sign_request(&url, &body, SystemTime::now())?; + let response = self + .client + .post(signed.url) + .headers(signed.headers) + .body(body) + .send() + .await?; + let status = response.status(); + let response_body = response.bytes().await?; + + if !status.is_success() { + let body = String::from_utf8_lossy(&response_body).into_owned(); + return Err(BedrockError::HttpStatus(status, body)); + } + + let response: ApplyGuardrailResponse = serde_json::from_slice(&response_body)?; + outcome_from_response(payload, &message_indexes, response) + } +} + +#[derive(Debug, Error)] +pub enum BedrockError { + #[error("bedrock guardrail payload is not supported: {0}")] + UnsupportedPayload(String), + #[error("failed to build bedrock guardrail url: {0}")] + Url(String), + #[error("failed to sign bedrock guardrail request: {0}")] + Signing(String), + #[error("bedrock guardrail request failed: {0}")] + Request(#[from] reqwest::Error), + #[error("failed to encode/decode bedrock guardrail json: {0}")] + Json(#[from] serde_json::Error), + #[error("bedrock guardrail returned HTTP {0}: {1}")] + HttpStatus(StatusCode, String), +} + +#[derive(Debug, Clone)] +struct SignedRequest { + url: Url, + headers: HeaderMap, +} + +#[derive(Debug, Serialize)] +struct ApplyGuardrailRequest { + content: Vec, + #[serde(rename = "outputScope", skip_serializing_if = "Option::is_none")] + output_scope: Option<&'static str>, + source: &'static str, +} + +#[derive(Debug, Serialize)] +struct ApplyGuardrailContentBlock { + text: ApplyGuardrailTextBlock, +} + +#[derive(Debug, Serialize)] +struct ApplyGuardrailTextBlock { + text: String, +} + +#[derive(Debug, Deserialize)] +struct ApplyGuardrailResponse { + action: String, + #[serde(rename = "actionReason")] + action_reason: Option, + #[serde(default)] + outputs: Vec, +} + +#[derive(Debug, Deserialize)] +struct ApplyGuardrailOutput { + text: String, 
+} + +fn encode_path_segment(segment: &str) -> String { + utf8_percent_encode(segment, PATH_SEGMENT_ENCODE_SET).to_string() +} + +fn build_url_for_endpoint(base_url: &str, endpoint_path: &str) -> Result { + let mut parsed = Url::parse(base_url).map_err(|error| BedrockError::Url(error.to_string()))?; + + let base_segments = parsed + .path_segments() + .map(|segments| { + segments + .filter(|segment| !segment.is_empty()) + .collect::>() + }) + .unwrap_or_default(); + let endpoint_segments = endpoint_path + .split('/') + .filter(|segment| !segment.is_empty()) + .collect::>(); + + let max_overlap = base_segments.len().min(endpoint_segments.len()); + let overlap = (1..=max_overlap) + .rev() + .find(|count| base_segments[base_segments.len() - count..] == endpoint_segments[..*count]) + .unwrap_or(0); + + let mut joined_segments = base_segments; + joined_segments.extend_from_slice(&endpoint_segments[overlap..]); + + parsed.set_path(&format!("/{}", joined_segments.join("/"))); + Ok(parsed) +} + +fn text_blocks_from_payload( + payload: &GuardrailCheckPayload, +) -> (Vec, Vec) { + let messages = match payload { + GuardrailCheckPayload::Input(payload) => &payload.messages, + GuardrailCheckPayload::Output(payload) => &payload.messages, + }; + + messages + .iter() + .enumerate() + .filter_map(|(index, message)| message_text_for_bedrock(message).map(|text| (index, text))) + .fold( + (Vec::new(), Vec::new()), + |(mut content, mut indexes), (index, text)| { + indexes.push(index); + content.push(ApplyGuardrailContentBlock { + text: ApplyGuardrailTextBlock { text }, + }); + (content, indexes) + }, + ) +} + +fn message_text_for_bedrock(message: &GuardrailMessage) -> Option { + match &message.content { + Some(GuardrailMessageContent::Text(text)) if !text.is_empty() => Some(text.clone()), + _ => None, + } +} + +fn outcome_from_response( + payload: &GuardrailCheckPayload, + message_indexes: &[usize], + response: ApplyGuardrailResponse, +) -> Result { + match response.action.as_str() { + 
"NONE" => Ok(GuardrailOutcome::Allow), + "GUARDRAIL_BLOCKED" => Ok(GuardrailOutcome::Block { + reason: response + .action_reason + .unwrap_or_else(|| "bedrock guardrail blocked".into()), + }), + "GUARDRAIL_INTERVENED" => { + if response.outputs.is_empty() { + return Ok(GuardrailOutcome::Block { + reason: response + .action_reason + .unwrap_or_else(|| "bedrock guardrail intervened".into()), + }); + } + + if response.outputs.len() != message_indexes.len() { + return Ok(GuardrailOutcome::Block { + reason: format!( + "bedrock guardrail returned {} rewritten outputs for {} text messages", + response.outputs.len(), + message_indexes.len() + ), + }); + } + + let mut rewritten = payload.clone(); + let messages = match &mut rewritten { + GuardrailCheckPayload::Input(payload) => &mut payload.messages, + GuardrailCheckPayload::Output(payload) => &mut payload.messages, + }; + + for (index, output) in message_indexes.iter().zip(response.outputs) { + messages[*index].content = Some(GuardrailMessageContent::Text(output.text)); + } + + Ok(GuardrailOutcome::Rewrite(rewritten)) + } + other => Err(BedrockError::UnsupportedPayload(format!( + "unknown bedrock guardrail action {other}" + ))), + } +} + +#[cfg(test)] +mod tests { + use std::time::{Duration, UNIX_EPOCH}; + + use reqwest::Url; + + use super::{ + ApplyGuardrailOutput, ApplyGuardrailResponse, BedrockError, BedrockGuardrailConfig, + BedrockGuardrailRuntime, BedrockRuntimeConfig, build_url_for_endpoint, + outcome_from_response, + }; + use crate::traits::{ + GuardrailCheckPayload, GuardrailMessage, GuardrailMessageContent, GuardrailOutcome, + GuardrailRole, InputGuardrailPayload, + }; + + fn runtime() -> BedrockGuardrailRuntime { + BedrockGuardrailRuntime::new(BedrockRuntimeConfig { + region: "us-east-1".into(), + access_key_id: "AKIA123".into(), + secret_access_key: "secret".into(), + session_token: Some("token".into()), + endpoint: Some("https://bedrock-runtime.us-east-1.amazonaws.com/guardrail".into()), + }) + } + + fn 
/// The `Debug` rendering of `BedrockRuntimeConfig` must never contain the
/// access key id, the secret access key, or the session token.
///
/// The secret values are searched for as quoted strings rather than as
/// `field: "value"` pairs, so the check does not depend on how the field
/// label is laid out. Both the compact and the pretty form are covered.
#[test]
fn bedrock_runtime_config_debug_redacts_credentials() {
    let config = BedrockRuntimeConfig {
        region: "us-east-1".into(),
        access_key_id: "AKIA123".into(),
        secret_access_key: "secret".into(),
        session_token: Some("token".into()),
        endpoint: Some("https://bedrock-runtime.us-east-1.amazonaws.com".into()),
    };

    for output in [format!("{config:?}"), format!("{config:#?}")] {
        assert!(output.contains("[REDACTED]"));
        assert!(!output.contains("AKIA123"));
        // Quoted on purpose: the bare words also occur inside the field names
        // `secret_access_key` and `session_token`.
        assert!(!output.contains("\"secret\""));
        assert!(!output.contains("\"token\""));
        // Non-sensitive fields remain visible for diagnostics.
        assert!(output.contains("us-east-1"));
    }
}
br#"{"content":[{"text":{"text":"hello"}}],"source":"INPUT"}"#, + UNIX_EPOCH + Duration::from_secs(1_700_000_000), + ) + .unwrap(); + + assert!( + signed + .headers + .get(http::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.starts_with("AWS4-HMAC-SHA256")) + ); + } + + #[test] + fn outcome_from_response_allows_when_no_intervention() { + let outcome = outcome_from_response( + &input_payload("hello"), + &[0], + ApplyGuardrailResponse { + action: "NONE".into(), + action_reason: None, + outputs: vec![], + }, + ) + .unwrap(); + + assert_eq!(outcome, GuardrailOutcome::Allow); + } + + #[test] + fn outcome_from_response_rewrites_text_messages() { + let outcome = outcome_from_response( + &input_payload("hello"), + &[0], + ApplyGuardrailResponse { + action: "GUARDRAIL_INTERVENED".into(), + action_reason: Some("filtered".into()), + outputs: vec![ApplyGuardrailOutput { + text: "hi there".into(), + }], + }, + ) + .unwrap(); + + assert_eq!( + outcome, + GuardrailOutcome::Rewrite(GuardrailCheckPayload::Input(InputGuardrailPayload { + messages: vec![GuardrailMessage { + role: GuardrailRole::User, + content: Some(GuardrailMessageContent::Text("hi there".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }], + })) + ); + } + + #[test] + fn outcome_from_response_blocks_when_intervention_has_no_rewrite_output() { + let outcome = outcome_from_response( + &input_payload("hello"), + &[0], + ApplyGuardrailResponse { + action: "GUARDRAIL_INTERVENED".into(), + action_reason: Some("policy triggered".into()), + outputs: vec![], + }, + ) + .unwrap(); + + assert_eq!( + outcome, + GuardrailOutcome::Block { + reason: "policy triggered".into(), + } + ); + } + + #[test] + fn outcome_from_response_blocks_when_bedrock_reports_blocked() { + let outcome = outcome_from_response( + &input_payload("hello"), + &[0], + ApplyGuardrailResponse { + action: "GUARDRAIL_BLOCKED".into(), + action_reason: Some("policy triggered".into()), + outputs: 
vec![], + }, + ) + .unwrap(); + + assert_eq!( + outcome, + GuardrailOutcome::Block { + reason: "policy triggered".into(), + } + ); + } + + #[test] + fn outcome_from_response_rejects_unknown_actions() { + let error = outcome_from_response( + &input_payload("hello"), + &[0], + ApplyGuardrailResponse { + action: "MAYBE".into(), + action_reason: None, + outputs: vec![], + }, + ) + .unwrap_err(); + + assert!(matches!(error, BedrockError::UnsupportedPayload(_))); + } +} diff --git a/crates/aisix-guardrail/src/guardrails/mod.rs b/crates/aisix-guardrail/src/guardrails/mod.rs new file mode 100644 index 0000000..435d7a0 --- /dev/null +++ b/crates/aisix-guardrail/src/guardrails/mod.rs @@ -0,0 +1,13 @@ +pub mod bedrock; + +pub use bedrock::{BedrockGuardrailMeta, BedrockGuardrailRuntime, BedrockRuntimeConfig}; + +pub mod identifiers { + use super::bedrock; + + pub const BEDROCK: &str = bedrock::IDENTIFIER; +} + +pub mod configs { + pub use super::bedrock::BedrockGuardrailConfig; +} diff --git a/crates/aisix-guardrail/src/lib.rs b/crates/aisix-guardrail/src/lib.rs index f6ac8fc..d99c24e 100644 --- a/crates/aisix-guardrail/src/lib.rs +++ b/crates/aisix-guardrail/src/lib.rs @@ -1 +1,2 @@ +pub mod guardrails; pub mod traits; diff --git a/tests/package.json b/tests/package.json index 65e0bd1..a8c1c80 100644 --- a/tests/package.json +++ b/tests/package.json @@ -8,7 +8,7 @@ "test": "vitest run", "test:dev": "vitest" }, - "packageManager": "pnpm@11.0.9+sha512.34ce82e6780233cf9cad8685029a8f81d2e06196c5a9bad98879f7424940c6817c4e4524fb7d38b8553ceed48b9758b8ebaf1abd3600c232c4c8cf7366086f38", + "packageManager": "pnpm@11.1.0+sha512.0c44e842e5686b2c061a81adda8b2258bd8818e9704b2cf2c63d56b931a7b2e910092e085027003b96ca3911ab56a07f6df5abaed2be9925034cdd686a535b14", "devDependencies": { "@anthropic-ai/sdk": "^0.88.0", "@eslint/js": "^10.0.1",