diff --git a/src/proxy/guardrails.rs b/src/proxy/guardrails.rs new file mode 100644 index 0000000..ab8c488 --- /dev/null +++ b/src/proxy/guardrails.rs @@ -0,0 +1,686 @@ +use aisix_guardrail::{ + guardrails::BedrockGuardrailRuntime, + traits::{ + GuardrailCheckPayload, GuardrailContentPart, GuardrailImageUrl, GuardrailMessage, + GuardrailMessageContent, GuardrailOutcome, GuardrailRole, GuardrailRuntime, GuardrailStage, + GuardrailToolCall, InputGuardrailPayload, OutputGuardrailPayload, + }, +}; +use async_trait::async_trait; +use thiserror::Error; + +use crate::{ + config::entities::{Model, ResourceEntry, ResourceRegistry, guardrails::GuardrailConfig}, + gateway::{ + error::GatewayError, + types::openai::{ + ChatMessage, ContentPart, FunctionCall, ImageUrl, MessageContent, ToolCall, + }, + }, +}; + +#[derive(Debug, Error, PartialEq, Eq)] +pub(crate) enum GuardrailBridgeError { + #[error("unsupported chat message role: {0}")] + UnsupportedRole(String), + #[error("guardrail payload did not match the expected stage: {0}")] + UnexpectedStage(&'static str), +} + +#[cfg(test)] +#[derive(Debug, Error)] +pub(crate) enum GuardrailExecutionError +where + E: std::error::Error + 'static, +{ + #[error(transparent)] + Bridge(#[from] GuardrailBridgeError), + #[error(transparent)] + Runtime(E), +} + +#[async_trait] +pub(crate) trait ConfiguredGuardrailRuntime: Send + Sync { + fn name(&self) -> &'static str; + + fn supports_stage(&self, stage: GuardrailStage) -> bool; + + async fn check( + &self, + payload: &GuardrailCheckPayload, + ) -> Result, GatewayError>; +} + +struct GuardrailRuntimeHandle { + runtime: R, + config: C, +} + +impl GuardrailRuntimeHandle { + fn new(runtime: R, config: C) -> Self { + Self { runtime, config } + } +} + +#[async_trait] +impl ConfiguredGuardrailRuntime for GuardrailRuntimeHandle +where + R: GuardrailRuntime + Send + Sync, + C: Send + Sync, + R::Error: std::error::Error + Send + Sync + 'static, +{ + fn name(&self) -> &'static str { + self.runtime.name() + } + + fn supports_stage(&self, stage: GuardrailStage) -> bool { + self.runtime.supports_stage(stage) + } + + async fn check( + &self, + payload: &GuardrailCheckPayload, + ) -> Result, GatewayError> { + if !self.supports_stage(payload.stage()) { + return Ok(None); + } + + self.runtime + .check(payload, &self.config) + .await + .map(Some) + .map_err(|error| { + GatewayError::Internal(format!( + "guardrail {} failed: {}", + self.runtime.name(), + error + )) + }) + } +} + +pub(crate) fn chat_message_to_guardrail_message( + message: &ChatMessage, +) -> Result { + Ok(GuardrailMessage { + role: guardrail_role_from_chat_role(&message.role)?, + content: message + .content + .as_ref() + .map(guardrail_content_from_chat_content), + name: message.name.clone(), + tool_calls: message.tool_calls.as_ref().map(|tool_calls| { + tool_calls + .iter() + .map(|tool_call| GuardrailToolCall { + id: tool_call.id.clone(), + r#type: tool_call.r#type.clone(), + name: tool_call.function.name.clone(), + arguments: tool_call.function.arguments.clone(), + }) + .collect() + }), + tool_call_id: message.tool_call_id.clone(), + }) +} + +pub(crate) fn guardrail_message_to_chat_message( + message: &GuardrailMessage, +) -> Result { + Ok(ChatMessage { + role: chat_role_from_guardrail_role(message.role.clone()).to_owned(), + content: message + .content + .as_ref() + .map(chat_content_from_guardrail_content), + name: message.name.clone(), + tool_calls: message.tool_calls.as_ref().map(|tool_calls| { + tool_calls + .iter() + .map(|tool_call| ToolCall { + id: tool_call.id.clone(), + r#type: tool_call.r#type.clone(), + function: FunctionCall { + name: tool_call.name.clone(), + arguments: tool_call.arguments.clone(), + }, + }) + .collect() + }), + tool_call_id: message.tool_call_id.clone(), + }) +} + +pub(crate) fn input_guardrail_payload_from_chat_messages( + messages: &[ChatMessage], +) -> Result { + Ok(InputGuardrailPayload { + messages: messages + .iter() + .map(chat_message_to_guardrail_message) + .collect::, _>>()?, + }) +} + +pub(crate) fn output_guardrail_payload_from_chat_messages( + messages: &[ChatMessage], +) -> Result { + Ok(OutputGuardrailPayload { + messages: messages + .iter() + .map(chat_message_to_guardrail_message) + .collect::, _>>()?, + }) +} + +pub(crate) fn input_payload_to_chat_messages( + payload: &InputGuardrailPayload, +) -> Result, GuardrailBridgeError> { + payload + .messages + .iter() + .map(guardrail_message_to_chat_message) + .collect() +} + +pub(crate) fn output_payload_to_chat_messages( + payload: &OutputGuardrailPayload, +) -> Result, GuardrailBridgeError> { + payload + .messages + .iter() + .map(guardrail_message_to_chat_message) + .collect() +} + +pub(crate) fn input_payload_from_check_payload( + payload: GuardrailCheckPayload, +) -> Result { + match payload { + GuardrailCheckPayload::Input(payload) => Ok(payload), + GuardrailCheckPayload::Output(_) => Err(GuardrailBridgeError::UnexpectedStage("input")), + } +} + +pub(crate) fn output_payload_from_check_payload( + payload: GuardrailCheckPayload, +) -> Result { + match payload { + GuardrailCheckPayload::Output(payload) => Ok(payload), + GuardrailCheckPayload::Input(_) => Err(GuardrailBridgeError::UnexpectedStage("output")), + } +} + +pub(crate) fn resolve_model_guardrails( + model: &ResourceEntry, + resources: &ResourceRegistry, +) -> Result>, GatewayError> { + // This direct Model -> guardrail lookup is intentionally temporary. The long-term attachment + // point should come from policy evaluation so request-time guardrail selection is not encoded + // in the model resource itself. + model + .guardrail_ids + .iter() + .map(|guardrail_id| { + let guardrail = resources + .guardrails + .get_by_id(guardrail_id) + .ok_or_else(|| { + GatewayError::Internal(format!("guardrail {} not found", guardrail_id)) + })?; + + configured_guardrail_runtime_from_configs(&guardrail.guardrail) + }) + .collect() +} + +#[cfg(test)] +pub(crate) async fn run_guardrail_check( + runtime: &R, + config: &C, + payload: GuardrailCheckPayload, +) -> Result, GuardrailExecutionError> +where + R: GuardrailRuntime, +{ + if !runtime.supports_stage(payload.stage()) { + return Ok(None); + } + + Ok(Some( + runtime + .check(&payload, config) + .await + .map_err(GuardrailExecutionError::Runtime)?, + )) +} + +#[cfg(test)] +pub(crate) async fn run_input_guardrail_check( + runtime: &R, + config: &C, + messages: &[ChatMessage], +) -> Result, GuardrailExecutionError> +where + R: GuardrailRuntime, +{ + let payload = + GuardrailCheckPayload::Input(input_guardrail_payload_from_chat_messages(messages)?); + + run_guardrail_check(runtime, config, payload).await +} + +fn configured_guardrail_runtime_from_configs( + guardrail: &GuardrailConfig, +) -> Result, GatewayError> { + match guardrail { + GuardrailConfig::Bedrock(config) => Ok(Box::new(GuardrailRuntimeHandle::new( + BedrockGuardrailRuntime::new(), + config.clone(), + ))), + } +} + +fn guardrail_role_from_chat_role(role: &str) -> Result { + match role { + "system" => Ok(GuardrailRole::System), + "user" => Ok(GuardrailRole::User), + "assistant" => Ok(GuardrailRole::Assistant), + "tool" => Ok(GuardrailRole::Tool), + other => Err(GuardrailBridgeError::UnsupportedRole(other.to_string())), + } +} + +fn chat_role_from_guardrail_role(role: GuardrailRole) -> &'static str { + match role { + GuardrailRole::System => "system", + GuardrailRole::User => "user", + GuardrailRole::Assistant => "assistant", + GuardrailRole::Tool => "tool", + } +} + +fn guardrail_content_from_chat_content(content: &MessageContent) -> GuardrailMessageContent { + match content { + MessageContent::Text(text) => GuardrailMessageContent::Text(text.clone()), + MessageContent::Parts(parts) => GuardrailMessageContent::Parts( + parts + .iter() + .map(guardrail_content_part_from_chat_content_part) + .collect(), + ), + } +} + +fn chat_content_from_guardrail_content(content: &GuardrailMessageContent) -> MessageContent { + match content { + GuardrailMessageContent::Text(text) => MessageContent::Text(text.clone()), + GuardrailMessageContent::Parts(parts) => MessageContent::Parts( + parts + .iter() + .map(chat_content_part_from_guardrail_content_part) + .collect(), + ), + } +} + +fn guardrail_content_part_from_chat_content_part(part: &ContentPart) -> GuardrailContentPart { + match part { + ContentPart::Text { text } => GuardrailContentPart::Text { text: text.clone() }, + ContentPart::ImageUrl { image_url } => GuardrailContentPart::ImageUrl { + image_url: GuardrailImageUrl { + url: image_url.url.clone(), + detail: image_url.detail.clone(), + }, + }, + } +} + +fn chat_content_part_from_guardrail_content_part(part: &GuardrailContentPart) -> ContentPart { + match part { + GuardrailContentPart::Text { text } => ContentPart::Text { text: text.clone() }, + GuardrailContentPart::ImageUrl { image_url } => ContentPart::ImageUrl { + image_url: ImageUrl { + url: image_url.url.clone(), + detail: image_url.detail.clone(), + }, + }, + } +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use aisix_guardrail::{ + guardrails::configs::BedrockGuardrailConfig, + traits::{ + GuardrailCheckPayload, GuardrailContentPart, GuardrailMessage, GuardrailMessageContent, + GuardrailMeta, GuardrailOutcome, GuardrailRole, GuardrailRuntime, GuardrailStage, + GuardrailToolCall, InputGuardrailPayload, OutputGuardrailPayload, + }, + }; + use async_trait::async_trait; + use thiserror::Error; + + use super::{ + GuardrailBridgeError, chat_message_to_guardrail_message, + configured_guardrail_runtime_from_configs, guardrail_message_to_chat_message, + input_guardrail_payload_from_chat_messages, input_payload_from_check_payload, + input_payload_to_chat_messages, output_guardrail_payload_from_chat_messages, + output_payload_from_check_payload, output_payload_to_chat_messages, + run_input_guardrail_check, + }; + use crate::{ + config::entities::guardrails::GuardrailConfig, + gateway::types::openai::{ + ChatMessage, ContentPart, FunctionCall, MessageContent, ToolCall, + }, + }; + + const INPUT_ONLY_STAGES: &[GuardrailStage] = &[GuardrailStage::Input]; + const OUTPUT_ONLY_STAGES: &[GuardrailStage] = &[GuardrailStage::Output]; + + #[derive(Debug, Error)] + #[error("mock guardrail runtime error")] + struct MockGuardrailError; + + struct RecordingGuardrailRuntime { + supported_stages: &'static [GuardrailStage], + outcome: GuardrailOutcome, + seen_payloads: Mutex>, + } + + impl RecordingGuardrailRuntime { + fn new(supported_stages: &'static [GuardrailStage], outcome: GuardrailOutcome) -> Self { + Self { + supported_stages, + outcome, + seen_payloads: Mutex::new(Vec::new()), + } + } + } + + impl GuardrailMeta for RecordingGuardrailRuntime { + fn name(&self) -> &'static str { + "recording" + } + + fn supported_stages(&self) -> &'static [GuardrailStage] { + self.supported_stages + } + } + + #[async_trait] + impl GuardrailRuntime<()> for RecordingGuardrailRuntime { + type Error = MockGuardrailError; + + async fn check( + &self, + payload: &GuardrailCheckPayload, + _config: &(), + ) -> Result { + self.seen_payloads.lock().unwrap().push(payload.clone()); + Ok(self.outcome.clone()) + } + } + + #[test] + fn chat_message_to_guardrail_message_preserves_core_fields() { + let message = ChatMessage { + role: "assistant".into(), + content: Some(MessageContent::Text("hello".into())), + name: Some("planner".into()), + tool_calls: Some(vec![ToolCall { + id: "call_weather_1".into(), + r#type: "function".into(), + function: FunctionCall { + name: "get_weather".into(), + arguments: r#"{"city":"Hangzhou"}"#.into(), + }, + }]), + tool_call_id: None, + }; + + let guardrail_message = chat_message_to_guardrail_message(&message).unwrap(); + + assert_eq!(guardrail_message.role, GuardrailRole::Assistant); + assert_eq!( + guardrail_message.content, + Some(GuardrailMessageContent::Text("hello".into())) + ); + assert_eq!(guardrail_message.name.as_deref(), Some("planner")); + assert_eq!( + guardrail_message.tool_calls, + Some(vec![GuardrailToolCall { + id: "call_weather_1".into(), + r#type: "function".into(), + name: "get_weather".into(), + arguments: r#"{"city":"Hangzhou"}"#.into(), + }]) + ); + } + + #[test] + fn guardrail_message_to_chat_message_round_trips_core_fields() { + let message = GuardrailMessage { + role: GuardrailRole::Assistant, + content: Some(GuardrailMessageContent::Parts(vec![ + GuardrailContentPart::Text { + text: "describe this image".into(), + }, + GuardrailContentPart::ImageUrl { + image_url: aisix_guardrail::traits::GuardrailImageUrl { + url: "https://example.com/cat.png".into(), + detail: Some("high".into()), + }, + }, + ])), + name: Some("planner".into()), + tool_calls: Some(vec![GuardrailToolCall { + id: "call_weather_1".into(), + r#type: "function".into(), + name: "get_weather".into(), + arguments: r#"{"city":"Hangzhou"}"#.into(), + }]), + tool_call_id: None, + }; + + let chat_message = guardrail_message_to_chat_message(&message).unwrap(); + + assert_eq!(chat_message.role, "assistant"); + assert_eq!(chat_message.name.as_deref(), Some("planner")); + match chat_message.content { + Some(MessageContent::Parts(parts)) => { + assert_eq!(parts.len(), 2); + assert!(matches!( + &parts[0], + ContentPart::Text { text } if text == "describe this image" + )); + assert!(matches!( + &parts[1], + ContentPart::ImageUrl { image_url } + if image_url.url == "https://example.com/cat.png" + && image_url.detail.as_deref() == Some("high") + )); + } + other => panic!("expected multipart content, got {other:?}"), + } + } + + #[test] + fn input_guardrail_payload_from_chat_messages_builds_message_list() { + let messages = vec![ + ChatMessage { + role: "system".into(), + content: Some(MessageContent::Text("be concise".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ChatMessage { + role: "tool".into(), + content: Some(MessageContent::Text(r#"{"ok":true}"#.into())), + name: None, + tool_calls: None, + tool_call_id: Some("call_weather_1".into()), + }, + ]; + + let payload = input_guardrail_payload_from_chat_messages(&messages).unwrap(); + + assert_eq!(payload.messages.len(), 2); + assert_eq!(payload.messages[0].role, GuardrailRole::System); + assert_eq!(payload.messages[1].role, GuardrailRole::Tool); + assert_eq!( + payload.messages[1].tool_call_id.as_deref(), + Some("call_weather_1") + ); + } + + #[test] + fn output_payload_round_trips_chat_messages() { + let messages = vec![ChatMessage { + role: "assistant".into(), + content: Some(MessageContent::Text("hello".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }]; + + let payload = output_guardrail_payload_from_chat_messages(&messages).unwrap(); + let round_trip = output_payload_to_chat_messages(&payload).unwrap(); + + assert_eq!(round_trip.len(), 1); + assert_eq!(round_trip[0].role, "assistant"); + assert!(matches!( + &round_trip[0].content, + Some(MessageContent::Text(text)) if text == "hello" + )); + } + + #[test] + fn check_payload_stage_extractors_reject_mismatches() { + assert!(matches!( + input_payload_from_check_payload(GuardrailCheckPayload::Output( + OutputGuardrailPayload::default(), + )), + Err(GuardrailBridgeError::UnexpectedStage("input")) + )); + assert!(matches!( + output_payload_from_check_payload(GuardrailCheckPayload::Input( + InputGuardrailPayload::default(), + )), + Err(GuardrailBridgeError::UnexpectedStage("output")) + )); + } + + #[tokio::test] + async fn run_input_guardrail_check_should_bridge_and_call_runtime() { + let runtime = RecordingGuardrailRuntime::new(INPUT_ONLY_STAGES, GuardrailOutcome::Allow); + let messages = vec![ChatMessage { + role: "user".into(), + content: Some(MessageContent::Text("hello".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }]; + + let outcome = run_input_guardrail_check(&runtime, &(), &messages) + .await + .unwrap(); + + assert_eq!(outcome, Some(GuardrailOutcome::Allow)); + assert_eq!( + runtime.seen_payloads.lock().unwrap().as_slice(), + &[GuardrailCheckPayload::Input(InputGuardrailPayload { + messages: vec![GuardrailMessage { + role: GuardrailRole::User, + content: Some(GuardrailMessageContent::Text("hello".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }], + })] + ); + } + + #[tokio::test] + async fn run_input_guardrail_check_should_skip_unsupported_stage() { + let runtime = RecordingGuardrailRuntime::new(OUTPUT_ONLY_STAGES, GuardrailOutcome::Allow); + let messages = vec![ChatMessage { + role: "user".into(), + content: Some(MessageContent::Text("hello".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }]; + + let outcome = run_input_guardrail_check(&runtime, &(), &messages) + .await + .unwrap(); + + assert_eq!(outcome, None); + assert!(runtime.seen_payloads.lock().unwrap().is_empty()); + } + + #[test] + fn configured_guardrail_runtime_from_configs_builds_bedrock_runtime() { + let runtime = configured_guardrail_runtime_from_configs(&GuardrailConfig::Bedrock( + BedrockGuardrailConfig { + identifier: "guardrail-123".into(), + version: "1".into(), + region: "us-east-1".into(), + access_key_id: "AKIA123".into(), + secret_access_key: "secret".into(), + session_token: None, + endpoint: None, + }, + )) + .unwrap(); + + assert_eq!(runtime.name(), "bedrock"); + assert!(runtime.supports_stage(GuardrailStage::Input)); + } + + #[test] + fn input_payload_to_chat_messages_round_trips() { + let payload = InputGuardrailPayload { + messages: vec![GuardrailMessage { + role: GuardrailRole::Tool, + content: Some(GuardrailMessageContent::Text(r#"{"ok":true}"#.into())), + name: None, + tool_calls: None, + tool_call_id: Some("call_weather_1".into()), + }], + }; + + let messages = input_payload_to_chat_messages(&payload).unwrap(); + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, "tool"); + assert!(matches!( + &messages[0].content, + Some(MessageContent::Text(text)) if text == r#"{"ok":true}"# + )); + assert_eq!(messages[0].tool_call_id.as_deref(), Some("call_weather_1")); + } + + #[test] + fn chat_message_to_guardrail_message_rejects_unsupported_roles() { + let message = ChatMessage { + role: "developer".into(), + content: Some(MessageContent::Text("hello".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }; + + let error = chat_message_to_guardrail_message(&message).unwrap_err(); + + assert_eq!( + error, + GuardrailBridgeError::UnsupportedRole("developer".into()) + ); + } +} diff --git a/src/proxy/handlers/chat_completions/mod.rs b/src/proxy/handlers/chat_completions/mod.rs index 487a18a..b1c2dee 100644 --- a/src/proxy/handlers/chat_completions/mod.rs +++ b/src/proxy/handlers/chat_completions/mod.rs @@ -10,16 +10,22 @@ use span_attributes::{ }; pub use types::ChatCompletionError; +use super::FormatHandlerAdapter; use crate::{ gateway::{ + error::GatewayError, formats::OpenAIChatFormat, - traits::ProviderCapabilities, + traits::{ChatFormat, ProviderCapabilities}, types::{ common::Usage, openai::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse}, }, }, - proxy::handlers::format_handler::FormatHandlerAdapter, + proxy::guardrails::{ + input_guardrail_payload_from_chat_messages, input_payload_from_check_payload, + input_payload_to_chat_messages, output_guardrail_payload_from_chat_messages, + output_payload_from_check_payload, output_payload_to_chat_messages, + }, }; pub(crate) struct ChatCompletionsAdapter; @@ -86,7 +92,77 @@ impl FormatHandlerAdapter for ChatCompletionsAdapter { collector.output_message_span_properties() } + fn guardrail_input_payload( + _lifecycle_state: &Self::LifecycleState, + request: &Self::Request, + ) -> Result, Self::Error> { + let (hub_request, _) = OpenAIChatFormat::to_hub(request)?; + let payload = input_guardrail_payload_from_chat_messages(&hub_request.messages) + .map(crate::guardrail::traits::GuardrailCheckPayload::Input) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + + fn apply_input_guardrail_rewrite( + _lifecycle_state: &mut Self::LifecycleState, + request: &mut Self::Request, + rewrite: crate::guardrail::traits::GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + request.messages = input_payload_to_chat_messages( + &input_payload_from_check_payload(rewrite).map_err(bridge_error)?, + ) + .map_err(bridge_error)?; + Ok(()) + } + + fn guardrail_output_payload( + _lifecycle_state: &Self::LifecycleState, + response: &Self::Response, + ) -> Result, Self::Error> { + let messages = response + .choices + .iter() + .map(|choice| choice.message.clone()) + .collect::>(); + let payload = output_guardrail_payload_from_chat_messages(&messages) + .map(crate::guardrail::traits::GuardrailCheckPayload::Output) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + + fn apply_output_guardrail_rewrite( + _lifecycle_state: &mut Self::LifecycleState, + response: &mut Self::Response, + rewrite: crate::guardrail::traits::GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + let messages = output_payload_to_chat_messages( + &output_payload_from_check_payload(rewrite).map_err(bridge_error)?, + ) + .map_err(bridge_error)?; + + if messages.len() != response.choices.len() { + return Err(bridge_error(GatewayError::Bridge(format!( + "chat completion output guardrail rewrite expected {} messages, got {}", + response.choices.len(), + messages.len() + )))); + } + + for (choice, message) in response.choices.iter_mut().zip(messages) { + choice.message = message; + } + + Ok(()) + } + fn end_of_stream_event(saw_item: bool) -> Option { saw_item.then(|| SseEvent::default().data("[DONE]")) } } + +fn bridge_error(error: E) -> ChatCompletionError +where + E: std::fmt::Display, +{ + ChatCompletionError::GatewayError(GatewayError::Bridge(error.to_string())) +} diff --git a/src/proxy/handlers/format_handler.rs b/src/proxy/handlers/format_handler.rs index 81e228d..2471df4 100644 --- a/src/proxy/handlers/format_handler.rs +++ b/src/proxy/handlers/format_handler.rs @@ -20,6 +20,7 @@ use tokio::{ use crate::{ config::entities::{Model, ResourceEntry}, + guardrail::traits::{GuardrailCheckPayload, GuardrailOutcome}, gateway::{ error::GatewayError, traits::{ChatFormat, ProviderCapabilities}, @@ -30,6 +31,7 @@ use crate::{ }, proxy::{ AppState, + guardrails::{ConfiguredGuardrailRuntime, resolve_model_guardrails}, hooks::{ self, RequestContext, authorization::AuthorizationError, rate_limit::RateLimitError, }, @@ -85,6 +87,36 @@ pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)>; + fn guardrail_input_payload( + _lifecycle_state: &Self::LifecycleState, + _request: &Self::Request, + ) -> Result, Self::Error> { + Ok(None) + } + + fn apply_input_guardrail_rewrite( + _lifecycle_state: &mut Self::LifecycleState, + _request: &mut Self::Request, + _rewrite: GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn guardrail_output_payload( + _lifecycle_state: &Self::LifecycleState, + _response: &Self::Response, + ) -> Result, Self::Error> { + Ok(None) + } + + fn apply_output_guardrail_rewrite( + _lifecycle_state: &mut Self::LifecycleState, + _response: &mut Self::Response, + _rewrite: GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + Ok(()) + } + async fn prepare_lifecycle( _state: &AppState, _request_ctx: &mut RequestContext, @@ -183,11 +215,19 @@ where let provider = model.provider(resources.as_ref()).ok_or_else(|| { GatewayError::Internal(format!("provider {} not found", model.provider_id)) })?; + let configured_guardrails = resolve_model_guardrails(&model, resources.as_ref())?; let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; let provider_base_url = provider_instance.effective_base_url().ok(); let mut lifecycle_state = A::prepare_lifecycle(&state, &mut request_ctx, &mut request_data).await?; + apply_input_guardrails::( + &configured_guardrails, + &mut lifecycle_state, + &mut request_data, + ) + .await?; + let span = Span::enter_with_local_parent(A::span_name()); apply_span_properties( &span, @@ -214,6 +254,12 @@ where mut response, usage, })) => { + apply_output_guardrails::( + &configured_guardrails, + &mut lifecycle_state, + &mut response, + ) + .await?; A::handle_complete_response( &state, &mut request_ctx, @@ -247,6 +293,76 @@ where } } +async fn apply_input_guardrails( + guardrails: &[Box], + lifecycle_state: &mut A::LifecycleState, + request: &mut AdapterRequest, +) -> Result<(), A::Error> +where + A: FormatHandlerAdapter, +{ + for guardrail in guardrails { + let Some(payload) = A::guardrail_input_payload(lifecycle_state, request)? else { + continue; + }; + let Some(outcome) = guardrail.check(&payload).await? else { + continue; + }; + + match outcome { + GuardrailOutcome::Allow => {} + GuardrailOutcome::Rewrite(rewrite) => { + A::apply_input_guardrail_rewrite(lifecycle_state, request, rewrite)?; + } + GuardrailOutcome::Block { reason } => { + return Err(GatewayError::Validation(format!( + "guardrail {} blocked input: {}", + guardrail.name(), + reason + )) + .into()); + } + } + } + + Ok(()) +} + +async fn apply_output_guardrails( + guardrails: &[Box], + lifecycle_state: &mut A::LifecycleState, + response: &mut AdapterResponse, +) -> Result<(), A::Error> +where + A: FormatHandlerAdapter, +{ + for guardrail in guardrails { + let Some(payload) = A::guardrail_output_payload(lifecycle_state, response)? else { + continue; + }; + let Some(outcome) = guardrail.check(&payload).await? else { + continue; + }; + + match outcome { + GuardrailOutcome::Allow => {} + GuardrailOutcome::Rewrite(rewrite) => { + A::apply_output_guardrail_rewrite(lifecycle_state, response, rewrite)?; + } + GuardrailOutcome::Block { reason } => { + return Err(GatewayError::Validation(format!( + "guardrail {} blocked output: {}", + guardrail.name(), + reason + )) + .into()); + } + } + } + + Ok(()) +} + async fn handle_regular_response( response: AdapterResponse, usage: Usage, diff --git a/src/proxy/handlers/messages/mod.rs b/src/proxy/handlers/messages/mod.rs index 3e00a52..c7527a5 100644 --- a/src/proxy/handlers/messages/mod.rs +++ b/src/proxy/handlers/messages/mod.rs @@ -4,11 +4,13 @@ mod types; use axum::response::sse::Event as SseEvent; use fastrace::Span; use reqwest::Url; +use serde_json::Value; use span_attributes::{ StreamOutputCollector, chunk_span_properties, request_span_properties, response_span_properties, }; pub use types::MessagesError; +use super::FormatHandlerAdapter; use crate::{ gateway::{ error::GatewayError, @@ -16,12 +18,19 @@ use crate::{ traits::{ChatFormat, ProviderCapabilities}, types::{ anthropic::{ + AnthropicContent, AnthropicContentBlock, AnthropicMessage, AnthropicMessagesRequest, AnthropicMessagesResponse, AnthropicStreamEvent, + ImageSource, SystemBlock, SystemPrompt, }, common::Usage, + openai::{ContentPart, FunctionCall, ImageUrl, MessageContent, ToolCall}, }, }, - proxy::handlers::format_handler::FormatHandlerAdapter, + proxy::guardrails::{ + input_guardrail_payload_from_chat_messages, input_payload_from_check_payload, + input_payload_to_chat_messages, output_guardrail_payload_from_chat_messages, + output_payload_from_check_payload, output_payload_to_chat_messages, + }, }; fn anthropic_error_sse_event(message: String) -> SseEvent { @@ -103,6 +112,63 @@ impl FormatHandlerAdapter for MessagesAdapter { collector.output_message_span_properties() } + fn guardrail_input_payload( + _lifecycle_state: &Self::LifecycleState, + request: &Self::Request, + ) -> Result, Self::Error> { + let (hub_request, _) = AnthropicMessagesFormat::to_hub(request)?; + let payload = input_guardrail_payload_from_chat_messages(&hub_request.messages) + .map(crate::guardrail::traits::GuardrailCheckPayload::Input) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + + fn apply_input_guardrail_rewrite( + _lifecycle_state: &mut Self::LifecycleState, + request: &mut Self::Request, + rewrite: crate::guardrail::traits::GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + let messages = input_payload_to_chat_messages( + &input_payload_from_check_payload(rewrite).map_err(bridge_error)?, + ) + .map_err(bridge_error)?; + rewrite_anthropic_request_messages(request, &messages).map_err(MessagesError::from) + } + + fn guardrail_output_payload( + _lifecycle_state: &Self::LifecycleState, + response: &Self::Response, + ) -> Result, Self::Error> { + let messages = + vec![anthropic_response_to_chat_message(response).map_err(MessagesError::from)?]; + let payload = output_guardrail_payload_from_chat_messages(&messages) + .map(crate::guardrail::traits::GuardrailCheckPayload::Output) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + + fn apply_output_guardrail_rewrite( + _lifecycle_state: &mut Self::LifecycleState, + response: &mut Self::Response, + rewrite: crate::guardrail::traits::GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + let messages = output_payload_to_chat_messages( + &output_payload_from_check_payload(rewrite).map_err(bridge_error)?, + ) + .map_err(bridge_error)?; + let [message] = messages.as_slice() else { + return Err(bridge_error(GatewayError::Bridge(format!( + "anthropic output guardrail rewrite expected exactly 1 message, got {}", + messages.len() + )))); + }; + + response.role = message.role.clone(); + response.content = + anthropic_blocks_from_chat_message(message).map_err(MessagesError::from)?; + Ok(()) + } + fn serialize_stream_item(chunk: &Self::StreamChunk) -> SseEvent { serialize_stream_event(chunk) } @@ -111,3 +177,330 @@ impl FormatHandlerAdapter for MessagesAdapter { Some(anthropic_error_sse_event(error.to_string())) } } + +fn rewrite_anthropic_request_messages( + request: &mut AnthropicMessagesRequest, + messages: &[crate::gateway::types::openai::ChatMessage], +) -> Result<(), GatewayError> { + let (system, anthropic_messages) = anthropic_request_parts_from_chat_messages(messages)?; + request.system = system; + request.messages = anthropic_messages; + Ok(()) +} + +fn anthropic_request_parts_from_chat_messages( + messages: &[crate::gateway::types::openai::ChatMessage], +) -> Result<(Option, Vec), GatewayError> { + let split_index = messages + .iter() + .position(|message| message.role != "system") + .unwrap_or(messages.len()); + + if messages[split_index..] + .iter() + .any(|message| message.role == "system") + { + return Err(GatewayError::Bridge( + "Anthropic request rewrite requires system messages to remain at the front".into(), + )); + } + + let system = system_prompt_from_chat_messages(&messages[..split_index])?; + let anthropic_messages = messages[split_index..] + .iter() + .map(chat_message_to_anthropic_message) + .collect::, _>>()?; + Ok((system, anthropic_messages)) +} + +fn system_prompt_from_chat_messages( + messages: &[crate::gateway::types::openai::ChatMessage], +) -> Result, GatewayError> { + if messages.is_empty() { + return Ok(None); + } + + let mut blocks = Vec::new(); + for message in messages { + for text in message_content_text_segments(message.content.as_ref())? { + blocks.push(SystemBlock { + r#type: "text".into(), + text, + cache_control: None, + }); + } + } + + match blocks.as_slice() { + [] => Ok(None), + [single] => Ok(Some(SystemPrompt::Text(single.text.clone()))), + _ => Ok(Some(SystemPrompt::Blocks(blocks))), + } +} + +fn chat_message_to_anthropic_message( + message: &crate::gateway::types::openai::ChatMessage, +) -> Result { + match message.role.as_str() { + "user" | "assistant" => Ok(AnthropicMessage { + role: message.role.clone(), + content: anthropic_content_from_chat_message(message)?, + }), + "tool" => Ok(AnthropicMessage { + role: "user".into(), + content: AnthropicContent::Blocks(vec![AnthropicContentBlock::ToolResult { + tool_use_id: message.tool_call_id.clone().ok_or_else(|| { + GatewayError::Bridge( + "tool message rewrite requires tool_call_id for Anthropic tool_result" + .into(), + ) + })?, + content: anthropic_content_from_optional_message_content(message.content.as_ref())?, + is_error: None, + cache_control: None, + }]), + }), + other => Err(GatewayError::Bridge(format!( + "unsupported role {} for Anthropic request rewrite", + other + ))), + } +} + +fn anthropic_content_from_chat_message( + message: &crate::gateway::types::openai::ChatMessage, +) -> Result { + let mut blocks = anthropic_blocks_from_message_content(message.content.as_ref())?; + + if let Some(tool_calls) = &message.tool_calls { + if message.role != "assistant" { + return Err(GatewayError::Bridge( + "only assistant messages can carry tool calls in Anthropic rewrite".into(), + )); + } + + for tool_call in tool_calls { + if tool_call.r#type != "function" { + return Err(GatewayError::Bridge(format!( + "Anthropic rewrite only supports function tool calls, got {}", + tool_call.r#type + ))); + } + + let input = serde_json::from_str(&tool_call.function.arguments).map_err(|error| { + GatewayError::Bridge(format!( + "assistant tool call arguments are not valid JSON: {}", + error + )) + })?; + + blocks.push(AnthropicContentBlock::ToolUse { + id: tool_call.id.clone(), + name: tool_call.function.name.clone(), + input, + cache_control: None, + }); + } + } + + anthropic_content_from_blocks(blocks) +} + +fn anthropic_content_from_optional_message_content( + content: Option<&MessageContent>, +) -> Result, GatewayError> { + let blocks = anthropic_blocks_from_message_content(content)?; + if blocks.is_empty() { + Ok(None) + } else { + anthropic_content_from_blocks(blocks).map(Some) + } +} + +fn anthropic_content_from_blocks( + blocks: Vec, +) -> Result { + match blocks.as_slice() { + [] => Err(GatewayError::Bridge( + "Anthropic rewrite requires at least one content block".into(), + )), + [AnthropicContentBlock::Text { text, .. }] => Ok(AnthropicContent::Text(text.clone())), + _ => Ok(AnthropicContent::Blocks(blocks)), + } +} + +fn anthropic_blocks_from_chat_message( + message: &crate::gateway::types::openai::ChatMessage, +) -> Result, GatewayError> { + if message.role != "assistant" { + return Err(GatewayError::Bridge(format!( + "Anthropic response rewrite requires an assistant message, got {}", + message.role + ))); + } + + let mut blocks = anthropic_blocks_from_message_content(message.content.as_ref())?; + if let Some(tool_calls) = &message.tool_calls { + for tool_call in tool_calls { + if tool_call.r#type != "function" { + return Err(GatewayError::Bridge(format!( + "Anthropic response rewrite only supports function tool calls, got {}", + tool_call.r#type + ))); + } + let input: Value = + serde_json::from_str(&tool_call.function.arguments).map_err(|error| { + GatewayError::Bridge(format!( + "assistant tool call arguments are not valid JSON: {}", + error + )) + })?; + blocks.push(AnthropicContentBlock::ToolUse { + id: tool_call.id.clone(), + name: tool_call.function.name.clone(), + input, + cache_control: None, + }); + } + } + Ok(blocks) +} + +fn anthropic_blocks_from_message_content( + content: Option<&MessageContent>, +) -> Result, GatewayError> { + let Some(content) = content else { + return Ok(vec![]); + }; + + match content { + MessageContent::Text(text) => Ok(vec![AnthropicContentBlock::Text { + text: text.clone(), + cache_control: None, + }]), + MessageContent::Parts(parts) => parts + .iter() + .map(|part| match part { + ContentPart::Text { text } => Ok(AnthropicContentBlock::Text { + text: text.clone(), + cache_control: None, + }), + ContentPart::ImageUrl { image_url } => Ok(AnthropicContentBlock::Image { + source: image_url_to_source(&image_url.url)?, + cache_control: None, + }), + }) + .collect(), + } +} + +fn anthropic_response_to_chat_message( + response: &AnthropicMessagesResponse, +) -> Result { + let mut text_segments = Vec::new(); + let mut rich_parts = Vec::new(); + let mut tool_calls = Vec::new(); + let mut has_non_text_part = false; + + for block in &response.content { + match block { + AnthropicContentBlock::Text { text, .. } => { + text_segments.push(text.clone()); + rich_parts.push(ContentPart::Text { text: text.clone() }); + } + AnthropicContentBlock::Image { source, .. } => { + has_non_text_part = true; + rich_parts.push(ContentPart::ImageUrl { + image_url: ImageUrl { + url: format!("data:{};base64,{}", source.media_type, source.data), + detail: None, + }, + }); + } + AnthropicContentBlock::ToolUse { + id, name, input, .. + } => { + tool_calls.push(ToolCall { + id: id.clone(), + r#type: "function".into(), + function: FunctionCall { + name: name.clone(), + arguments: serde_json::to_string(input) + .map_err(|error| GatewayError::Transform(error.to_string()))?, + }, + }); + } + AnthropicContentBlock::ToolResult { .. } => { + return Err(GatewayError::Bridge( + "assistant response contained unsupported tool_result block".into(), + )); + } + } + } + + Ok(crate::gateway::types::openai::ChatMessage { + role: response.role.clone(), + content: if has_non_text_part { + Some(MessageContent::Parts(rich_parts)) + } else if !text_segments.is_empty() { + Some(MessageContent::Text(text_segments.join(""))) + } else { + None + }, + name: None, + tool_calls: (!tool_calls.is_empty()).then_some(tool_calls), + tool_call_id: None, + }) +} + +fn message_content_text_segments( + content: Option<&MessageContent>, +) -> Result, GatewayError> { + let Some(content) = content else { + return Ok(vec![]); + }; + + match content { + MessageContent::Text(text) => Ok(vec![text.clone()]), + MessageContent::Parts(parts) => parts + .iter() + .map(|part| match part { + ContentPart::Text { text } => Ok(text.clone()), + ContentPart::ImageUrl { .. } => Err(GatewayError::Bridge( + "Anthropic system prompt rewrite does not support image content".into(), + )), + }) + .collect(), + } +} + +fn image_url_to_source(url: &str) -> Result { + let Some(payload) = url.strip_prefix("data:") else { + return Err(GatewayError::Bridge( + "Anthropic rewrite only supports image_url data URLs for image content".into(), + )); + }; + let Some((metadata, data)) = payload.split_once(',') else { + return Err(GatewayError::Bridge( + "invalid data URL for Anthropic image content".into(), + )); + }; + let Some(media_type) = metadata.strip_suffix(";base64") else { + return Err(GatewayError::Bridge( + "Anthropic image content requires base64 data URLs".into(), + )); + }; + + Ok(ImageSource { + r#type: "base64".into(), + media_type: media_type.into(), + data: data.into(), + }) +} + +fn bridge_error(error: E) -> MessagesError +where + E: std::fmt::Display, +{ + MessagesError::GatewayError(GatewayError::Bridge(error.to_string())) +} diff --git a/src/proxy/handlers/mod.rs b/src/proxy/handlers/mod.rs index 4257741..53e9ade 100644 --- a/src/proxy/handlers/mod.rs +++ b/src/proxy/handlers/mod.rs @@ -1,7 +1,9 @@ pub mod chat_completions; pub mod embeddings; -pub mod format_handler; +mod format_handler; pub mod messages; pub mod models; pub(crate) mod openai_error; pub mod responses; + +pub(crate) use format_handler::{FormatHandlerAdapter, format_handler}; diff --git a/src/proxy/handlers/responses/mod.rs b/src/proxy/handlers/responses/mod.rs index 9f6f424..b5ac04b 100644 --- a/src/proxy/handlers/responses/mod.rs +++ b/src/proxy/handlers/responses/mod.rs @@ -9,7 +9,8 @@ use reqwest::Url; use runtime::{ ResponsesLifecycleState, accumulate_complete, accumulate_stream_event, accumulate_stream_success, build_merged_input_messages, init_lifecycle, load_previous_messages, - persist_if_enabled, + persist_if_enabled, response_output_to_chat_messages, rewrite_request_from_messages, + rewrite_response_from_messages, }; use span_attributes::{ StreamOutputCollector, chunk_span_properties, event_starts_output, request_span_properties, @@ -30,9 +31,19 @@ use crate::{ }, }, }, - proxy::{AppState, handlers::format_handler::FormatHandlerAdapter, hooks::RequestContext}, + proxy::{ + AppState, + guardrails::{ + input_guardrail_payload_from_chat_messages, input_payload_from_check_payload, + input_payload_to_chat_messages, output_guardrail_payload_from_chat_messages, + output_payload_from_check_payload, output_payload_to_chat_messages, + }, + hooks::RequestContext, + }, }; +use super::FormatHandlerAdapter; + fn serialize_stream_event(event: &ResponsesApiStreamEvent) -> SseEvent { let mut sse_event = SseEvent::default().data(ResponsesApiFormat::serialize_chunk_payload(event)); @@ -96,6 +107,51 @@ impl FormatHandlerAdapter for ResponsesAdapter { collector.output_message_span_properties() } + fn guardrail_input_payload( + lifecycle_state: &Self::LifecycleState, + _request: &Self::Request, + ) -> Result, Self::Error> { + let payload = input_guardrail_payload_from_chat_messages(&lifecycle_state.merged_input_messages) + .map(crate::guardrail::traits::GuardrailCheckPayload::Input) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + + fn apply_input_guardrail_rewrite( + lifecycle_state: &mut Self::LifecycleState, + request: &mut Self::Request, + rewrite: crate::guardrail::traits::GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + let messages = input_payload_to_chat_messages(&input_payload_from_check_payload(rewrite) + .map_err(bridge_error)?) + .map_err(bridge_error)?; + rewrite_request_from_messages(lifecycle_state, request, messages)?; + Ok(()) + } + + fn guardrail_output_payload( + _lifecycle_state: &Self::LifecycleState, + response: &Self::Response, + ) -> Result, Self::Error> { + let messages = response_output_to_chat_messages(&response.output); + let payload = output_guardrail_payload_from_chat_messages(&messages) + .map(crate::guardrail::traits::GuardrailCheckPayload::Output) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + + fn apply_output_guardrail_rewrite( + _lifecycle_state: &mut Self::LifecycleState, + response: &mut Self::Response, + rewrite: crate::guardrail::traits::GuardrailCheckPayload, + ) -> Result<(), Self::Error> { + let messages = output_payload_to_chat_messages(&output_payload_from_check_payload(rewrite) + .map_err(bridge_error)?) + .map_err(bridge_error)?; + rewrite_response_from_messages(response, &messages)?; + Ok(()) + } + async fn prepare_lifecycle( state: &AppState, _request_ctx: &mut RequestContext, @@ -106,6 +162,7 @@ impl FormatHandlerAdapter for ResponsesAdapter { let previous_messages = load_previous_messages(storage.as_ref(), request.previous_response_id.as_deref()) .await?; + lifecycle_state.replay_messages_len = previous_messages.len(); lifecycle_state.merged_input_messages = build_merged_input_messages(request, &previous_messages)?; request.replay_messages = previous_messages; @@ -205,3 +262,10 @@ fn rewrite_output_item_ids(item: &mut ResponsesOutputItem, response_id: &str, ou *id = format!("{}_message_{}", response_id, output_index); } } + +fn bridge_error(error: E) -> ResponsesError +where + E: std::fmt::Display, +{ + ResponsesError::GatewayError(GatewayError::Bridge(error.to_string())) +} diff --git a/src/proxy/handlers/responses/runtime.rs b/src/proxy/handlers/responses/runtime.rs index 4d035d8..f0d5dc5 100644 --- a/src/proxy/handlers/responses/runtime.rs +++ b/src/proxy/handlers/responses/runtime.rs @@ -27,6 +27,7 @@ use crate::{ pub(crate) struct ResponsesLifecycleState { pub response_id: String, pub previous_response_id: Option, + pub replay_messages_len: usize, pub merged_input_messages: Vec, pub model: String, pub metadata: HashMap, @@ -246,6 +247,7 @@ pub(crate) fn init_lifecycle(request: &ResponsesApiRequest) -> ResponsesLifecycl ResponsesLifecycleState { response_id: generate_response_id(), previous_response_id: request.previous_response_id.clone(), + replay_messages_len: 0, merged_input_messages: vec![], model: request.model.clone(), metadata: request_metadata(request), @@ -370,7 +372,7 @@ fn request_metadata(request: &ResponsesApiRequest) -> HashMap { .unwrap_or_default() } -fn request_input_messages(request: &ResponsesApiRequest) -> Result> { +pub(crate) fn request_input_messages(request: &ResponsesApiRequest) -> Result> { match &request.input { ResponsesInput::Text(text) => Ok(vec![ChatMessage { role: "user".into(), @@ -454,7 +456,7 @@ fn request_content_part_to_content_part( } } -fn response_output_to_chat_messages(output: &[ResponsesOutputItem]) -> Vec { +pub(crate) fn response_output_to_chat_messages(output: &[ResponsesOutputItem]) -> Vec { let mut messages = Vec::new(); let mut current_assistant_index = None; @@ -520,6 +522,38 @@ fn response_output_to_chat_messages(output: &[ResponsesOutputItem]) -> Vec, +) -> Result<()> { + if messages.len() < state.replay_messages_len { + return Err(GatewayError::Bridge(format!( + "responses guardrail rewrite returned {} messages, fewer than {} replay messages", + messages.len(), + state.replay_messages_len + ))); + } + + let replay_messages = messages[..state.replay_messages_len].to_vec(); + let current_messages = &messages[state.replay_messages_len..]; + let (instructions, input) = responses_request_body_from_messages(current_messages)?; + + state.merged_input_messages = messages; + request.replay_messages = replay_messages; + request.instructions = instructions; + request.input = input; + Ok(()) +} + +pub(crate) fn rewrite_response_from_messages( + response: &mut ResponsesApiResponse, + messages: &[ChatMessage], +) -> Result<()> { + response.output = chat_messages_to_response_output(&response.id, messages)?; + Ok(()) +} + fn response_output_content_to_message_content( content: &[ResponsesOutputContent], ) -> Option { @@ -539,6 +573,210 @@ fn response_output_content_to_message_content( } } +fn responses_request_body_from_messages( + messages: &[ChatMessage], +) -> Result<(Option, ResponsesInput)> { + let split_index = messages + .iter() + .position(|message| message.role != "system") + .unwrap_or(messages.len()); + + if messages[split_index..] + .iter() + .any(|message| message.role == "system") + { + return Err(GatewayError::Bridge( + "Responses request rewrite requires system messages to remain at the front".into(), + )); + } + + let instructions = if split_index == 0 { + None + } else { + let mut segments = Vec::new(); + for message in &messages[..split_index] { + segments.extend(message_content_text_segments(message.content.as_ref())?); + } + (!segments.is_empty()).then_some(segments.join("\n\n")) + }; + + let items = messages[split_index..] + .iter() + .map(chat_message_to_responses_input_item) + .collect::>>()?; + + Ok((instructions, ResponsesInput::Items(items))) +} + +fn chat_message_to_responses_input_item(message: &ChatMessage) -> Result { + match message.role.as_str() { + "user" | "assistant" | "system" => Ok(ResponsesInputItem::Message { + role: message.role.clone(), + content: message_content_to_responses_content(message.content.as_ref())?, + }), + "tool" => Ok(ResponsesInputItem::FunctionCallOutput { + call_id: message.tool_call_id.clone().ok_or_else(|| { + GatewayError::Bridge( + "Responses request rewrite requires tool_call_id for tool messages".into(), + ) + })?, + output: message_content_to_text(message.content.as_ref())?, + }), + other => Err(GatewayError::Bridge(format!( + "unsupported role {} for Responses request rewrite", + other + ))), + } +} + +fn message_content_to_responses_content(content: Option<&MessageContent>) -> Result { + let Some(content) = content else { + return Err(GatewayError::Bridge( + "Responses request rewrite requires message content".into(), + )); + }; + + match content { + MessageContent::Text(text) => Ok(ResponsesContent::Text(text.clone())), + MessageContent::Parts(parts) => Ok(ResponsesContent::Parts( + parts + .iter() + .map(|part| match part { + ContentPart::Text { text } => Ok(ResponsesContentPart::InputText { + text: text.clone(), + }), + ContentPart::ImageUrl { image_url } => Ok(ResponsesContentPart::InputImage { + image_url: Some(image_url.url.clone()), + file_id: None, + detail: image_url.detail.clone(), + }), + }) + .collect::>>()?, + )), + } +} + +fn message_content_to_text(content: Option<&MessageContent>) -> Result { + let Some(content) = content else { + return Ok(String::new()); + }; + + match content { + MessageContent::Text(text) => Ok(text.clone()), + MessageContent::Parts(parts) => { + let mut text = String::new(); + for part in parts { + match part { + ContentPart::Text { text: part_text } => text.push_str(part_text), + ContentPart::ImageUrl { .. } => { + return Err(GatewayError::Bridge( + "Responses text-only rewrite does not support image content here" + .into(), + )); + } + } + } + Ok(text) + } + } +} + +fn message_content_text_segments(content: Option<&MessageContent>) -> Result> { + let Some(content) = content else { + return Ok(vec![]); + }; + + match content { + MessageContent::Text(text) => Ok(vec![text.clone()]), + MessageContent::Parts(parts) => parts + .iter() + .map(|part| match part { + ContentPart::Text { text } => Ok(text.clone()), + ContentPart::ImageUrl { .. } => Err(GatewayError::Bridge( + "Responses instructions rewrite does not support image content".into(), + )), + }) + .collect(), + } +} + +fn chat_messages_to_response_output( + response_id: &str, + messages: &[ChatMessage], +) -> Result> { + let mut output = Vec::new(); + let mut next_output_index = 0; + + for message in messages { + let items = chat_message_to_response_output(response_id, next_output_index, message)?; + next_output_index += items.len(); + output.extend(items); + } + + Ok(output) +} + +fn chat_message_to_response_output( + response_id: &str, + next_output_index: usize, + message: &ChatMessage, +) -> Result> { + let mut output = Vec::new(); + let content = chat_message_content_to_response_output_content(message.content.as_ref())?; + + if !content.is_empty() { + output.push(ResponsesOutputItem::Message { + id: response_message_output_id(response_id, next_output_index), + role: message.role.clone(), + content, + status: "completed".into(), + }); + } + + if let Some(tool_calls) = &message.tool_calls { + let first_tool_output_index = next_output_index + output.len(); + for (offset, tool_call) in tool_calls.iter().enumerate() { + output.push(ResponsesOutputItem::FunctionCall { + id: response_function_call_output_id( + response_id, + first_tool_output_index + offset, + ), + call_id: tool_call.id.clone(), + name: tool_call.function.name.clone(), + arguments: tool_call.function.arguments.clone(), + status: "completed".into(), + }); + } + } + + Ok(output) +} + +fn chat_message_content_to_response_output_content( + content: Option<&MessageContent>, +) -> Result> { + let Some(content) = content else { + return Ok(vec![]); + }; + + match content { + MessageContent::Text(text) => Ok(vec![ResponsesOutputContent::OutputText { + text: text.clone(), + }]), + MessageContent::Parts(parts) => parts + .iter() + .map(|part| match part { + ContentPart::Text { text } => Ok(ResponsesOutputContent::OutputText { + text: text.clone(), + }), + ContentPart::ImageUrl { .. } => Err(GatewayError::Bridge( + "Responses output rewrite does not support image content".into(), + )), + }) + .collect(), + } +} + fn merge_metadata_value(metadata: &mut HashMap, extra: Option<&Value>) { let Some(extra) = extra.and_then(Value::as_object) else { return; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 59033d8..6adbed0 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,3 +1,4 @@ +mod guardrails; mod handlers; mod hooks; pub(crate) mod message_history; @@ -64,19 +65,19 @@ pub fn create_router(state: AppState) -> Result { .route( "/v1/chat/completions", post( - handlers::format_handler::format_handler::< + handlers::format_handler::< handlers::chat_completions::ChatCompletionsAdapter, >, ), ) .route( "/v1/messages", - post(handlers::format_handler::format_handler::) + post(handlers::format_handler::) .layer(DefaultBodyLimit::max(32 * 1024 * 1024)), ) .route( "/v1/responses", - post(handlers::format_handler::format_handler::), + post(handlers::format_handler::), ) .route("/v1/embeddings", post(handlers::embeddings::embeddings)) .layer(DefaultBodyLimit::max(10 * 1024 * 1024)) diff --git a/tests/package.json b/tests/package.json index a8c1c80..ccd9543 100644 --- a/tests/package.json +++ b/tests/package.json @@ -8,7 +8,7 @@ "test": "vitest run", "test:dev": "vitest" }, - "packageManager": "pnpm@11.1.0+sha512.0c44e842e5686b2c061a81adda8b2258bd8818e9704b2cf2c63d56b931a7b2e910092e085027003b96ca3911ab56a07f6df5abaed2be9925034cdd686a535b14", + "packageManager": "pnpm@11.1.1+sha512.d1fdf5f73c617b64fa1a56a81c3c8dfe0e966e33a6010aa256b517ae77be21d93e05affc0de1a83b0e4f29d569f68b446ae8f068cd7247c0bb3df0fb4d7bdf9a", "devDependencies": { "@anthropic-ai/sdk": "^0.88.0", "@eslint/js": "^10.0.1",