diff --git a/src/proxy/guardrails.rs b/src/proxy/guardrails.rs index 3c02175..132cb07 100644 --- a/src/proxy/guardrails.rs +++ b/src/proxy/guardrails.rs @@ -9,6 +9,8 @@ use aisix_guardrail::{ use async_trait::async_trait; use thiserror::Error; +pub(crate) mod streaming; + use crate::{ config::entities::{Model, ResourceEntry, ResourceRegistry, guardrails::GuardrailConfig}, gateway::{ diff --git a/src/proxy/guardrails/streaming.rs b/src/proxy/guardrails/streaming.rs new file mode 100644 index 0000000..aef7c40 --- /dev/null +++ b/src/proxy/guardrails/streaming.rs @@ -0,0 +1,188 @@ +use std::collections::VecDeque; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct StreamCheckpoint(pub u64); + +#[derive(Debug, Clone, PartialEq, Eq)] +#[allow(dead_code)] +pub(crate) enum StreamGuardrailDecision { + Pending, + Allow { approved_through: StreamCheckpoint }, + Block { reason: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum WholeResponseReplayAction { + Buffered(Chunk), + Emit(Chunk), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum WholeResponseReplayFinalize { + NeedsGuardrailCheck, + Finished, +} + +#[derive(Debug)] +pub(crate) struct WholeResponseReplay { + next_checkpoint: u64, + buffered_chunks: VecDeque<(StreamCheckpoint, Chunk)>, +} + +impl Default for WholeResponseReplay { + fn default() -> Self { + Self { + next_checkpoint: 0, + buffered_chunks: VecDeque::new(), + } + } +} + +impl WholeResponseReplay { + pub(crate) fn push(&mut self, chunk: Chunk) -> StreamGuardrailDecision { + let checkpoint = StreamCheckpoint(self.next_checkpoint); + self.next_checkpoint = self.next_checkpoint.saturating_add(1); + self.buffered_chunks.push_back((checkpoint, chunk)); + StreamGuardrailDecision::Pending + } + + pub(crate) fn allow_all(self) -> (StreamGuardrailDecision, VecDeque) { + let approved_through = self + .buffered_chunks + .back() + .map(|(checkpoint, _)| *checkpoint) + .unwrap_or(StreamCheckpoint(0)); + let buffered_chunks = self + .buffered_chunks + .into_iter() + .map(|(_, chunk)| chunk) + .collect(); + + ( + StreamGuardrailDecision::Allow { approved_through }, + buffered_chunks, + ) + } +} + +#[derive(Debug, Default)] +pub(crate) struct WholeResponseReplayDriver { + replay: Option>, + replay_queue: VecDeque, + upstream_finished: bool, +} + +impl WholeResponseReplayDriver { + pub(crate) fn new(enabled: bool) -> Self { + Self { + replay: enabled.then(WholeResponseReplay::default), + replay_queue: VecDeque::new(), + upstream_finished: false, + } + } + + pub(crate) fn take_replay_chunk(&mut self) -> Option { + self.replay_queue.pop_front() + } + + pub(crate) fn finish_upstream(&mut self) -> WholeResponseReplayFinalize { + if self.replay.is_some() { + WholeResponseReplayFinalize::NeedsGuardrailCheck + } else { + self.upstream_finished = true; + WholeResponseReplayFinalize::Finished + } + } + + pub(crate) fn is_upstream_finished(&self) -> bool { + self.upstream_finished + } + + pub(crate) fn is_buffering(&self) -> bool { + self.replay.is_some() + } +} + +impl WholeResponseReplayDriver { + pub(crate) fn push_upstream_chunk(&mut self, chunk: Chunk) -> WholeResponseReplayAction { + if let Some(replay) = self.replay.as_mut() { + let decision = replay.push(chunk.clone()); + debug_assert!(matches!(decision, StreamGuardrailDecision::Pending)); + WholeResponseReplayAction::Buffered(chunk) + } else { + WholeResponseReplayAction::Emit(chunk) + } + } + + pub(crate) fn approve_buffered(&mut self) -> StreamGuardrailDecision { + let Some(replay) = self.replay.take() else { + debug_assert!( + false, + "approve_buffered called without buffered replay state" + ); + return StreamGuardrailDecision::Allow { + approved_through: StreamCheckpoint(0), + }; + }; + + let (decision, drained_chunks) = replay.allow_all(); + self.replay_queue = drained_chunks; + self.upstream_finished = true; + decision + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::{ + StreamGuardrailDecision, WholeResponseReplayAction, WholeResponseReplayDriver, + WholeResponseReplayFinalize, + }; + + #[test] + fn whole_response_replay_driver_passes_through_when_disabled() { + let mut driver = WholeResponseReplayDriver::new(false); + + assert_eq!( + driver.push_upstream_chunk(7_u8), + WholeResponseReplayAction::Emit(7) + ); + assert!(!driver.is_buffering()); + assert_eq!( + driver.finish_upstream(), + WholeResponseReplayFinalize::Finished + ); + assert!(driver.is_upstream_finished()); + assert_eq!(driver.take_replay_chunk(), None); + } + + #[test] + fn whole_response_replay_driver_replays_buffered_chunks_after_approval() { + let mut driver = WholeResponseReplayDriver::new(true); + + assert_eq!( + driver.push_upstream_chunk(String::from("safe ")), + WholeResponseReplayAction::Buffered(String::from("safe ")), + ); + assert_eq!( + driver.push_upstream_chunk(String::from("response")), + WholeResponseReplayAction::Buffered(String::from("response")), + ); + assert!(driver.is_buffering()); + assert_eq!( + driver.finish_upstream(), + WholeResponseReplayFinalize::NeedsGuardrailCheck, + ); + assert!(matches!( + driver.approve_buffered(), + StreamGuardrailDecision::Allow { .. } + )); + assert!(!driver.is_buffering()); + assert!(driver.is_upstream_finished()); + assert_eq!(driver.take_replay_chunk(), Some(String::from("safe "))); + assert_eq!(driver.take_replay_chunk(), Some(String::from("response"))); + assert_eq!(driver.take_replay_chunk(), None); + } +} diff --git a/src/proxy/handlers/chat_completions/mod.rs b/src/proxy/handlers/chat_completions/mod.rs index b1c2dee..15b2986 100644 --- a/src/proxy/handlers/chat_completions/mod.rs +++ b/src/proxy/handlers/chat_completions/mod.rs @@ -5,6 +5,7 @@ use axum::response::sse::Event as SseEvent; use fastrace::Span; use opentelemetry_semantic_conventions::attribute::GEN_AI_RESPONSE_FINISH_REASONS; use reqwest::Url; +use serde_json::json; use span_attributes::{ StreamOutputCollector, chunk_span_properties, request_span_properties, response_span_properties, }; @@ -30,6 +31,19 @@ use crate::{ pub(crate) struct ChatCompletionsAdapter; +fn openai_error_sse_event(message: String) -> SseEvent { + SseEvent::default().data( + json!({ + "error": { + "message": message, + "type": "invalid_request_error", + "code": "gateway_error", + } + }) + .to_string(), + ) +} + impl FormatHandlerAdapter for ChatCompletionsAdapter { type Format = OpenAIChatFormat; type Request = ChatCompletionRequest; @@ -155,6 +169,24 @@ impl FormatHandlerAdapter for ChatCompletionsAdapter { Ok(()) } + fn guardrail_stream_output_payload( + _lifecycle_state: &Self::LifecycleState, + collector: &Self::Collector, + ) -> Result, Self::Error> { + let payload = output_guardrail_payload_from_chat_messages(&collector.output_messages()) + .map(crate::guardrail::traits::GuardrailCheckPayload::Output) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + + fn lifecycle_error_event(error: &Self::Error) -> Option { + let message = match error { + ChatCompletionError::GatewayError(err) => err.to_string(), + _ => error.to_string(), + }; + Some(openai_error_sse_event(message)) + } + fn end_of_stream_event(saw_item: bool) -> Option { saw_item.then(|| SseEvent::default().data("[DONE]")) } diff --git a/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs b/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs index 3ed883c..089fea5 100644 --- a/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs +++ b/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs @@ -2,7 +2,9 @@ use std::collections::BTreeMap; use super::message_attributes::{MessageContentView, MessageView, OutputMessageView, ToolCallView}; use crate::{ - gateway::types::openai::ChatCompletionChunk, + gateway::types::openai::{ + ChatCompletionChunk, ChatMessage, FunctionCall, MessageContent, ToolCall, + }, proxy::utils::trace::span_message_attributes::output_message_span_properties, }; @@ -70,6 +72,36 @@ impl StreamOutputCollector { output_message_span_properties(&self.output_message_views()) } + pub(crate) fn output_messages(&self) -> Vec { + self.choices + .values() + .map(|choice| ChatMessage { + role: choice.role.clone().unwrap_or_else(|| "assistant".into()), + content: (!choice.content.is_empty()) + .then(|| MessageContent::Text(choice.content.clone())), + name: None, + tool_calls: (!choice.tool_calls.is_empty()).then(|| { + choice + .tool_calls + .values() + .filter_map(|tool_call| { + let name = tool_call.name.clone()?; + Some(ToolCall { + id: tool_call.id.clone().unwrap_or_default(), + r#type: "function".into(), + function: FunctionCall { + name, + arguments: tool_call.arguments.clone(), + }, + }) + }) + .collect() + }), + tool_call_id: None, + }) + .collect() + } + fn output_message_views(&self) -> Vec { self.choices .values() diff --git a/src/proxy/handlers/format_handler.rs b/src/proxy/handlers/format_handler.rs index 2471df4..58441bf 100644 --- a/src/proxy/handlers/format_handler.rs +++ b/src/proxy/handlers/format_handler.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, time::Duration}; +use std::{convert::Infallible, sync::Arc, time::Duration}; use async_trait::async_trait; use axum::{ @@ -20,7 +20,6 @@ use tokio::{ use crate::{ config::entities::{Model, ResourceEntry}, - guardrail::traits::{GuardrailCheckPayload, GuardrailOutcome}, gateway::{ error::GatewayError, traits::{ChatFormat, ProviderCapabilities}, @@ -29,9 +28,16 @@ use crate::{ response::{ChatResponse, ChatResponseStream}, }, }, + guardrail::traits::{GuardrailCheckPayload, GuardrailOutcome}, proxy::{ AppState, - guardrails::{ConfiguredGuardrailRuntime, resolve_model_guardrails}, + guardrails::{ + ConfiguredGuardrailRuntime, resolve_model_guardrails, + streaming::{ + StreamGuardrailDecision, WholeResponseReplayAction, WholeResponseReplayDriver, + WholeResponseReplayFinalize, + }, + }, hooks::{ self, RequestContext, authorization::AuthorizationError, rate_limit::RateLimitError, }, @@ -55,13 +61,14 @@ pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { >; type Request: Sync; type Response: Serialize; - type StreamChunk: Serialize + Send + 'static; + type StreamChunk: Clone + Serialize + Send + 'static; type Error: IntoResponse + std::fmt::Display + From + From + From - + From; + + From + + Send; type Collector: Default + Send + 'static; type LifecycleState: Default + Send + 'static; @@ -117,6 +124,13 @@ pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { Ok(()) } + fn guardrail_stream_output_payload( + _lifecycle_state: &Self::LifecycleState, + _collector: &Self::Collector, + ) -> Result, Self::Error> { + Ok(None) + } + async fn prepare_lifecycle( _state: &AppState, _request_ctx: &mut RequestContext, @@ -254,12 +268,19 @@ where mut response, usage, })) => { - apply_output_guardrails::( + let output_guardrail_result = apply_output_guardrails::( &configured_guardrails, &mut lifecycle_state, &mut response, ) - .await?; + .await; + apply_span_properties( + &span, + output_guardrail_failure_span_properties(&output_guardrail_result, || { + A::response_span_properties(&response, &usage) + }), + ); + output_guardrail_result?; A::handle_complete_response( &state, &mut request_ctx, @@ -274,6 +295,7 @@ where Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { handle_stream_response::( state, + configured_guardrails, stream, usage_rx, &mut request_ctx, @@ -363,6 +385,72 @@ where Ok(()) } +async fn apply_stream_output_guardrails( + guardrails: &[Box], + payload: &GuardrailCheckPayload, +) -> Result<(), A::Error> +where + A: FormatHandlerAdapter, +{ + for guardrail in guardrails { + let Some(outcome) = guardrail.check(payload).await? else { + continue; + }; + + match outcome { + GuardrailOutcome::Allow => {} + GuardrailOutcome::Rewrite(_) => { + return Err(GatewayError::Validation(format!( + "guardrail {} requested streaming output rewrite, which is not supported yet", + guardrail.name() + )) + .into()); + } + GuardrailOutcome::Block { reason } => { + return Err(GatewayError::Validation(format!( + "guardrail {} blocked output: {}", + guardrail.name(), + reason + )) + .into()); + } + } + } + + Ok(()) +} + +fn require_stream_output_guardrail_payload( + payload: Option, +) -> Result { + payload.ok_or_else(|| { + GatewayError::Internal( + "stream output guardrails were enabled, but the adapter did not provide a stream output payload" + .into(), + ) + }) +} + +fn has_output_guardrails(guardrails: &[Box]) -> bool { + guardrails + .iter() + .any(|guardrail| guardrail.supports_stage(crate::guardrail::traits::GuardrailStage::Output)) +} + +fn output_guardrail_failure_span_properties( + result: &Result<(), E>, + properties: F, +) -> Vec<(String, String)> +where + F: FnOnce() -> Vec<(String, String)>, +{ + if result.is_err() { + return properties(); + } + + Vec::new() +} + async fn handle_regular_response( response: AdapterResponse, usage: Usage, @@ -486,8 +574,28 @@ where A::handle_stream_success(state, request_ctx, lifecycle_state, None).await } +async fn record_first_stream_output_emit( + request_ctx: &mut RequestContext, + span: &Span, + first_output_arrived: &mut bool, + starts_output: bool, +) where + A: FormatHandlerAdapter, +{ + if *first_output_arrived || !starts_output { + return; + } + + *first_output_arrived = true; + hooks::observability::record_first_token_latency(request_ctx).await; + span.add_event( + TraceEvent::new("first token arrived").with_property(|| ("kind", "first_token_arrived")), + ); +} + async fn handle_stream_response( state: AppState, + configured_guardrails: Vec>, stream: ChatResponseStream>, usage_rx: oneshot::Receiver, request_ctx: &mut RequestContext, @@ -501,9 +609,13 @@ where let stream_request_ctx = request_ctx.clone(); let stream_state = state.clone(); + let replay_driver = + WholeResponseReplayDriver::new(has_output_guardrails(&configured_guardrails)); + let configured_guardrails = Arc::new(configured_guardrails); let sse_stream = futures::stream::unfold( ( stream_state, + configured_guardrails, stream, span, stream_request_ctx, @@ -511,28 +623,33 @@ where false, Some(usage_rx), AdapterCollector::::default(), + AdapterCollector::::default(), false, Some(lifecycle_state), + replay_driver, ), |( state, + configured_guardrails, mut stream, span, mut request_ctx, should_terminate, saw_item, mut usage_rx, + mut guardrail_output_collector, mut output_collector, mut first_output_arrived, mut lifecycle_state, + mut replay_driver, )| async move { if should_terminate { drop(span); return None; } - match stream.next().await { - Some(Ok(mut chunk)) => { + loop { + if let Some(mut chunk) = replay_driver.take_replay_chunk() { if let Some(lifecycle_state) = lifecycle_state.as_mut() { A::handle_stream_item( &state, @@ -542,74 +659,38 @@ where ); } - A::record_stream_item(&mut output_collector, &chunk); - - let now_starts_output = !first_output_arrived && A::starts_output(&chunk); - if now_starts_output { - first_output_arrived = true; - hooks::observability::record_first_token_latency(&mut request_ctx).await; - span.add_event( - TraceEvent::new("first token arrived") - .with_property(|| ("kind", "first_token_arrived")), - ); - } + record_first_stream_output_emit::( + &mut request_ctx, + &span, + &mut first_output_arrived, + A::starts_output(&chunk), + ) + .await; + A::record_stream_item(&mut output_collector, &chunk); A::apply_chunk_span_properties(&span, &chunk, !saw_item); - Some(( + break Some(( Ok::(A::serialize_stream_item(&chunk)), ( state, + configured_guardrails, stream, span, request_ctx, false, true, usage_rx, + guardrail_output_collector, output_collector, first_output_arrived, lifecycle_state, + replay_driver, ), - )) + )); } - Some(Err(err)) => { - error!("Gateway stream error: {}", err); - span.add_property(|| ("error.type", "stream_error")); - span.add_properties(|| A::output_message_span_properties(&output_collector)); - - if let Some(lifecycle_state) = lifecycle_state.take() { - if let Err(lifecycle_err) = - A::handle_stream_failure(&state, &mut request_ctx, lifecycle_state) - .await - { - error!("Stream failure lifecycle error: {}", lifecycle_err); - } - } - - finalize_stream_usage_observation(&mut request_ctx, &mut usage_rx, &span).await; - if let Some(event) = A::stream_error_event(&err) { - Some(( - Ok(event), - ( - state, - stream, - span, - request_ctx, - true, - saw_item, - usage_rx, - output_collector, - first_output_arrived, - lifecycle_state, - ), - )) - } else { - drop(span); - None - } - } - None => { + if replay_driver.is_upstream_finished() { match finalize_stream_success::( &state, &mut request_ctx, @@ -626,48 +707,248 @@ where span.add_property(|| ("error.type", "stream_success_lifecycle_error")); if let Some(event) = A::lifecycle_error_event(&err) { - return Some(( + break Some(( Ok(event), ( state, + configured_guardrails, stream, span, request_ctx, true, saw_item, usage_rx, + guardrail_output_collector, output_collector, first_output_arrived, lifecycle_state, + replay_driver, ), )); } drop(span); - return None; + break None; } } if let Some(event) = A::end_of_stream_event(saw_item) { - Some(( + break Some(( Ok(event), ( state, + configured_guardrails, stream, span, request_ctx, true, saw_item, usage_rx, + guardrail_output_collector, output_collector, first_output_arrived, lifecycle_state, + replay_driver, ), - )) - } else { + )); + } + + drop(span); + break None; + } + + match stream.next().await { + Some(Ok(chunk)) => match replay_driver.push_upstream_chunk(chunk) { + WholeResponseReplayAction::Buffered(chunk) => { + A::record_stream_item(&mut guardrail_output_collector, &chunk); + continue; + } + WholeResponseReplayAction::Emit(mut chunk) => { + if let Some(lifecycle_state) = lifecycle_state.as_mut() { + A::handle_stream_item( + &state, + &mut request_ctx, + lifecycle_state, + &mut chunk, + ); + } + + record_first_stream_output_emit::( + &mut request_ctx, + &span, + &mut first_output_arrived, + A::starts_output(&chunk), + ) + .await; + + A::record_stream_item(&mut output_collector, &chunk); + A::apply_chunk_span_properties(&span, &chunk, !saw_item); + + break Some(( + Ok::(A::serialize_stream_item(&chunk)), + ( + state, + configured_guardrails, + stream, + span, + request_ctx, + false, + true, + usage_rx, + guardrail_output_collector, + output_collector, + first_output_arrived, + lifecycle_state, + replay_driver, + ), + )); + } + }, + Some(Err(err)) => { + error!("Gateway stream error: {}", err); + span.add_property(|| ("error.type", "stream_error")); + if replay_driver.is_buffering() { + span.add_properties(|| { + A::output_message_span_properties(&guardrail_output_collector) + }); + } else { + span.add_properties(|| { + A::output_message_span_properties(&output_collector) + }); + } + + if let Some(lifecycle_state) = lifecycle_state.take() { + if let Err(lifecycle_err) = + A::handle_stream_failure(&state, &mut request_ctx, lifecycle_state) + .await + { + error!("Stream failure lifecycle error: {}", lifecycle_err); + } + } + + finalize_stream_usage_observation(&mut request_ctx, &mut usage_rx, &span) + .await; + + if let Some(event) = A::stream_error_event(&err) { + break Some(( + Ok(event), + ( + state, + configured_guardrails, + stream, + span, + request_ctx, + true, + saw_item, + usage_rx, + guardrail_output_collector, + output_collector, + first_output_arrived, + lifecycle_state, + replay_driver, + ), + )); + } + drop(span); - None + break None; } + None => match replay_driver.finish_upstream() { + WholeResponseReplayFinalize::NeedsGuardrailCheck => { + let output_guardrail_result = + if let Some(lifecycle_state) = lifecycle_state.as_ref() { + match A::guardrail_stream_output_payload( + lifecycle_state, + &guardrail_output_collector, + ) { + Ok(payload) => { + match require_stream_output_guardrail_payload(payload) { + Ok(payload) => { + apply_stream_output_guardrails::( + configured_guardrails.as_ref(), + &payload, + ) + .await + } + Err(err) => Err(err.into()), + } + } + Err(err) => Err(err), + } + } else { + Ok(()) + }; + + match output_guardrail_result { + Ok(()) => { + let decision = replay_driver.approve_buffered(); + debug_assert!(matches!( + decision, + StreamGuardrailDecision::Allow { .. } + )); + continue; + } + Err(err) => { + error!("Stream output guardrail error: {}", err); + span.add_property(|| { + ("error.type", "stream_output_guardrail_error") + }); + span.add_properties(|| { + A::output_message_span_properties( + &guardrail_output_collector, + ) + }); + + if let Some(lifecycle_state) = lifecycle_state.take() { + if let Err(lifecycle_err) = A::handle_stream_failure( + &state, + &mut request_ctx, + lifecycle_state, + ) + .await + { + error!( + "Stream failure lifecycle error: {}", + lifecycle_err + ); + } + } + + finalize_stream_usage_observation( + &mut request_ctx, + &mut usage_rx, + &span, + ) + .await; + + if let Some(event) = A::lifecycle_error_event(&err) { + break Some(( + Ok(event), + ( + state, + configured_guardrails, + stream, + span, + request_ctx, + true, + saw_item, + usage_rx, + guardrail_output_collector, + output_collector, + first_output_arrived, + lifecycle_state, + replay_driver, + ), + )); + } + + drop(span); + break None; + } + } + } + WholeResponseReplayFinalize::Finished => {} + }, } } }, @@ -677,3 +958,79 @@ where hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; Ok(response) } + +#[cfg(test)] +mod tests { + use std::cell::Cell; + + use pretty_assertions::assert_eq; + + use super::{ + output_guardrail_failure_span_properties, require_stream_output_guardrail_payload, + }; + use crate::{ + gateway::error::GatewayError, + guardrail::traits::{GuardrailCheckPayload, OutputGuardrailPayload}, + }; + + #[test] + fn output_guardrail_failure_span_properties_skips_computation_on_success() { + let computed = Cell::new(false); + + let properties = + output_guardrail_failure_span_properties(&Ok::<(), &'static str>(()), || { + computed.set(true); + vec![( + "llm.output_messages.0.message.content".into(), + "hidden".into(), + )] + }); + + assert!(properties.is_empty()); + assert!(!computed.get()); + } + + #[test] + fn output_guardrail_failure_span_properties_returns_properties_on_error() { + let computed = Cell::new(false); + + let properties = + output_guardrail_failure_span_properties(&Err::<(), &'static str>("blocked"), || { + computed.set(true); + vec![( + "llm.output_messages.0.message.content".into(), + "raw upstream output".into(), + )] + }); + + assert_eq!( + properties, + vec![( + "llm.output_messages.0.message.content".into(), + "raw upstream output".into(), + )] + ); + assert!(computed.get()); + } + + #[test] + fn require_stream_output_guardrail_payload_rejects_missing_payload() { + let err = require_stream_output_guardrail_payload(None).unwrap_err(); + + assert!(matches!(err, GatewayError::Internal(_))); + assert_eq!( + err.to_string(), + "internal: stream output guardrails were enabled, but the adapter did not provide a stream output payload", + ); + } + + #[test] + fn require_stream_output_guardrail_payload_passes_through_present_payload() { + let payload = GuardrailCheckPayload::Output(OutputGuardrailPayload::default()); + + assert_eq!( + require_stream_output_guardrail_payload(Some(payload.clone())).unwrap(), + payload, + ); + } +} diff --git a/src/proxy/handlers/messages/mod.rs b/src/proxy/handlers/messages/mod.rs index c7527a5..19f6544 100644 --- a/src/proxy/handlers/messages/mod.rs +++ b/src/proxy/handlers/messages/mod.rs @@ -169,6 +169,16 @@ impl FormatHandlerAdapter for MessagesAdapter { Ok(()) } + fn guardrail_stream_output_payload( + _lifecycle_state: &Self::LifecycleState, + collector: &Self::Collector, + ) -> Result, Self::Error> { + let payload = output_guardrail_payload_from_chat_messages(&collector.output_messages()) + .map(crate::guardrail::traits::GuardrailCheckPayload::Output) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + fn serialize_stream_item(chunk: &Self::StreamChunk) -> SseEvent { serialize_stream_event(chunk) } @@ -176,6 +186,10 @@ impl FormatHandlerAdapter for MessagesAdapter { fn stream_error_event(error: &GatewayError) -> Option { Some(anthropic_error_sse_event(error.to_string())) } + + fn lifecycle_error_event(error: &Self::Error) -> Option { + Some(anthropic_error_sse_event(error.to_string())) + } } fn rewrite_anthropic_request_messages( diff --git a/src/proxy/handlers/messages/span_attributes/stream_output.rs b/src/proxy/handlers/messages/span_attributes/stream_output.rs index ba3c8e5..2f41d6c 100644 --- a/src/proxy/handlers/messages/span_attributes/stream_output.rs +++ b/src/proxy/handlers/messages/span_attributes/stream_output.rs @@ -5,7 +5,10 @@ use super::message_attributes::{ message_content_view_from_content_parts, }; use crate::{ - gateway::types::anthropic::{AnthropicContentBlock, AnthropicStreamEvent, ContentDelta}, + gateway::types::{ + anthropic::{AnthropicContentBlock, AnthropicStreamEvent, ContentDelta}, + openai::{ChatMessage, ContentPart, FunctionCall, ImageUrl, MessageContent, ToolCall}, + }, proxy::utils::trace::span_message_attributes::output_message_span_properties, }; @@ -105,6 +108,60 @@ impl StreamOutputCollector { output_message_span_properties(&self.output_message_views()) } + pub(crate) fn output_messages(&self) -> Vec { + if self.role.is_none() && self.blocks.is_empty() { + return Vec::new(); + } + + let mut content_parts = Vec::new(); + let mut tool_calls = Vec::new(); + + for block in self.blocks.values() { + match block { + StreamOutputBlock::Text(text) if !text.is_empty() => { + content_parts.push(ContentPart::Text { text: text.clone() }); + } + StreamOutputBlock::ImageUrl { url } => { + content_parts.push(ContentPart::ImageUrl { + image_url: ImageUrl { + url: url.clone(), + detail: None, + }, + }); + } + StreamOutputBlock::ToolUse { + id, + name, + arguments, + } if !name.is_empty() => { + tool_calls.push(ToolCall { + id: id.clone().unwrap_or_default(), + r#type: "function".into(), + function: FunctionCall { + name: name.clone(), + arguments: arguments.clone(), + }, + }); + } + StreamOutputBlock::Text(_) | StreamOutputBlock::ToolUse { .. } => {} + } + } + + let content = match content_parts.as_slice() { + [] => None, + [ContentPart::Text { text }] => Some(MessageContent::Text(text.clone())), + _ => Some(MessageContent::Parts(content_parts)), + }; + + vec![ChatMessage { + role: self.role.clone().unwrap_or_else(|| "assistant".into()), + content, + name: None, + tool_calls: (!tool_calls.is_empty()).then_some(tool_calls), + tool_call_id: None, + }] + } + fn output_message_views(&self) -> Vec { if self.role.is_none() && self.blocks.is_empty() && self.finish_reason.is_none() { return Vec::new(); diff --git a/src/proxy/handlers/responses/mod.rs b/src/proxy/handlers/responses/mod.rs index b5ac04b..8cccd6f 100644 --- a/src/proxy/handlers/responses/mod.rs +++ b/src/proxy/handlers/responses/mod.rs @@ -18,6 +18,7 @@ use span_attributes::{ }; pub use types::ResponsesError; +use super::FormatHandlerAdapter; use crate::{ gateway::{ error::GatewayError, @@ -42,8 +43,6 @@ use crate::{ }, }; -use super::FormatHandlerAdapter; - fn serialize_stream_event(event: &ResponsesApiStreamEvent) -> SseEvent { let mut sse_event = SseEvent::default().data(ResponsesApiFormat::serialize_chunk_payload(event)); @@ -111,9 +110,10 @@ impl FormatHandlerAdapter for ResponsesAdapter { lifecycle_state: &Self::LifecycleState, _request: &Self::Request, ) -> Result, Self::Error> { - let payload = input_guardrail_payload_from_chat_messages(&lifecycle_state.merged_input_messages) - .map(crate::guardrail::traits::GuardrailCheckPayload::Input) - .map_err(bridge_error)?; + let payload = + input_guardrail_payload_from_chat_messages(&lifecycle_state.merged_input_messages) + .map(crate::guardrail::traits::GuardrailCheckPayload::Input) + .map_err(bridge_error)?; Ok(Some(payload)) } @@ -122,9 +122,10 @@ impl FormatHandlerAdapter for ResponsesAdapter { request: &mut Self::Request, rewrite: crate::guardrail::traits::GuardrailCheckPayload, ) -> Result<(), Self::Error> { - let messages = input_payload_to_chat_messages(&input_payload_from_check_payload(rewrite) - .map_err(bridge_error)?) - .map_err(bridge_error)?; + let messages = input_payload_to_chat_messages( + &input_payload_from_check_payload(rewrite).map_err(bridge_error)?, + ) + .map_err(bridge_error)?; rewrite_request_from_messages(lifecycle_state, request, messages)?; Ok(()) } @@ -145,13 +146,24 @@ impl FormatHandlerAdapter for ResponsesAdapter { response: &mut Self::Response, rewrite: crate::guardrail::traits::GuardrailCheckPayload, ) -> Result<(), Self::Error> { - let messages = output_payload_to_chat_messages(&output_payload_from_check_payload(rewrite) - .map_err(bridge_error)?) - .map_err(bridge_error)?; + let messages = output_payload_to_chat_messages( + &output_payload_from_check_payload(rewrite).map_err(bridge_error)?, + ) + .map_err(bridge_error)?; rewrite_response_from_messages(response, &messages)?; Ok(()) } + fn guardrail_stream_output_payload( + _lifecycle_state: &Self::LifecycleState, + collector: &Self::Collector, + ) -> Result, Self::Error> { + let payload = output_guardrail_payload_from_chat_messages(&collector.output_messages()) + .map(crate::guardrail::traits::GuardrailCheckPayload::Output) + .map_err(bridge_error)?; + Ok(Some(payload)) + } + async fn prepare_lifecycle( state: &AppState, _request_ctx: &mut RequestContext, diff --git a/src/proxy/handlers/responses/span_attributes/stream_output.rs b/src/proxy/handlers/responses/span_attributes/stream_output.rs index 3953b97..45d5591 100644 --- a/src/proxy/handlers/responses/span_attributes/stream_output.rs +++ b/src/proxy/handlers/responses/span_attributes/stream_output.rs @@ -2,10 +2,17 @@ use std::collections::BTreeMap; use super::message_attributes::output_message_views_from_output_items; use crate::{ - gateway::types::openai::responses::{ - ResponsesApiResponse, ResponsesApiStreamEvent, ResponsesOutputContent, ResponsesOutputItem, + gateway::types::openai::{ + ChatMessage, + responses::{ + ResponsesApiResponse, ResponsesApiStreamEvent, ResponsesOutputContent, + ResponsesOutputItem, + }, + }, + proxy::{ + handlers::responses::runtime::response_output_to_chat_messages, + utils::trace::span_message_attributes::output_message_span_properties, }, - proxy::utils::trace::span_message_attributes::output_message_span_properties, }; #[derive(Default)] @@ -116,6 +123,15 @@ impl StreamOutputCollector { output_message_span_properties(&output_message_views_from_output_items(&output)) } + pub(crate) fn output_messages(&self) -> Vec { + if let Some(response) = &self.completed_response { + return response_output_to_chat_messages(&response.output); + } + + let output: Vec<_> = self.items.values().cloned().collect(); + response_output_to_chat_messages(&output) + } + fn sync_response_output(&mut self, response: &ResponsesApiResponse) { for (output_index, item) in response.output.iter().cloned().enumerate() { self.items.insert(output_index, item); diff --git a/tests/proxy/guardrail/chat-completions.test.ts b/tests/proxy/guardrail/chat-completions.test.ts index 3e33f09..2a7eba6 100644 --- a/tests/proxy/guardrail/chat-completions.test.ts +++ b/tests/proxy/guardrail/chat-completions.test.ts @@ -1,4 +1,8 @@ -import { proxyPost } from '../../utils/proxy.js'; +import { parseSseDataEvents, proxyPost } from '../../utils/proxy.js'; +import { + expectStreamMatchesAssistantText, + expectStreamStopsBeforeDone, +} from '../../utils/stream-assert.js'; import { type RegexGuardrailFixture, setupOpenAiRegexGuardrailFixture, @@ -106,4 +110,132 @@ describe('proxy guardrail /v1/chat/completions', () => { ).messages[0]?.content, ).toBe('safe prompt for output guardrail'); }); + + test('output regex guardrail blocks matched upstream streamed response', async () => { + const resp = await proxyPost( + '/v1/chat/completions', + { + model: fixture?.outputGuardedModelName, + stream: true, + messages: [ + { + role: 'user', + content: 'safe prompt for streamed output guardrail', + }, + ], + }, + AUTHORIZED_KEY, + { responseType: 'text' }, + ); + + expect(resp.status).toBe(200); + expect(String(resp.headers['content-type'])).toContain('text/event-stream'); + + const events = expectStreamStopsBeforeDone(String(resp.data)); + const payloads = parseSseDataEvents(String(resp.data)).map( + (item) => + JSON.parse(item) as { + error?: { + message?: string; + type?: string; + code?: string; + }; + }, + ); + + expect(events).toHaveLength(1); + expect(payloads[0]?.error?.type).toBe('invalid_request_error'); + expect(payloads[0]?.error?.code).toBe('gateway_error'); + expect(payloads[0]?.error?.message).toContain( + 'guardrail regex blocked output', + ); + expect(payloads[0]?.error?.message).toContain( + 'blocked by regex output guardrail', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + }); + + test('approved streamed output replays original chunks after guardrail check', async () => { + fixture?.upstream.configure({ + streamEvents: [ + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: { role: 'assistant', content: 'safe ' }, + finish_reason: null, + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: { content: 'streamed response' }, + finish_reason: null, + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: {}, + finish_reason: 'stop', + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [], + usage: { + prompt_tokens: 10, + completion_tokens: 8, + total_tokens: 18, + }, + }, + '[DONE]', + ], + }); + + const resp = await proxyPost( + '/v1/chat/completions', + { + model: fixture?.outputGuardedModelName, + stream: true, + messages: [ + { role: 'user', content: 'safe prompt for streamed output replay' }, + ], + }, + AUTHORIZED_KEY, + { responseType: 'text' }, + ); + + expect(resp.status).toBe(200); + expect(String(resp.headers['content-type'])).toContain('text/event-stream'); + expectStreamMatchesAssistantText( + String(resp.data), + 'safe streamed response', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + }); }); diff --git a/tests/proxy/guardrail/messages.test.ts b/tests/proxy/guardrail/messages.test.ts index 68279a5..6da3daa 100644 --- a/tests/proxy/guardrail/messages.test.ts +++ b/tests/proxy/guardrail/messages.test.ts @@ -10,6 +10,28 @@ const AUTHORIZED_KEY = 'sk-proxy-guardrail-messages'; const UPSTREAM_API_KEY = 'upstream-key-guardrail-messages'; const UPSTREAM_MODEL = 'test-model'; +const parseAnthropicSseEvents = (sseBody: string) => { + const trimmed = sseBody.trim(); + if (!trimmed) { + return [] as Array<{ event?: string; data: string }>; + } + + return trimmed.split(/\r?\n\r?\n/).map((block) => { + const lines = block + .split(/\r?\n/) + .map((line) => line.trim()) + .filter(Boolean); + + return { + event: lines.find((line) => line.startsWith('event: '))?.slice(7), + data: lines + .filter((line) => line.startsWith('data: ')) + .map((line) => line.slice(6)) + .join('\n'), + }; + }); +}; + describe('proxy guardrail /v1/messages', () => { let fixture: RegexGuardrailFixture | undefined; @@ -106,4 +128,138 @@ describe('proxy guardrail /v1/messages', () => { ).messages[0]?.content, ).toBe('safe prompt for output guardrail'); }); + + test('output regex guardrail blocks matched streamed messages response', async () => { + const resp = await proxyPost( + '/v1/messages', + { + model: fixture?.outputGuardedModelName, + max_tokens: 256, + stream: true, + messages: [ + { + role: 'user', + content: 'safe prompt for streamed output guardrail', + }, + ], + }, + AUTHORIZED_KEY, + { responseType: 'text' }, + ); + + expect(resp.status).toBe(200); + expect(String(resp.headers['content-type'])).toContain('text/event-stream'); + + const events = parseAnthropicSseEvents(String(resp.data)); + expect(events).toHaveLength(1); + expect(events[0]?.event).toBe('error'); + + const payload = JSON.parse(events[0]?.data ?? '{}') as { + type?: string; + error?: { type?: string; message?: string }; + }; + expect(payload.type).toBe('error'); + expect(payload.error?.type).toBe('api_error'); + expect(payload.error?.message).toContain('guardrail regex blocked output'); + expect(payload.error?.message).toContain( + 'blocked by regex output guardrail', + ); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + }); + + test('approved streamed messages output replays original anthropic events', async () => { + fixture?.upstream.configure({ + streamEvents: [ + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: { role: 'assistant', content: 'safe ' }, + finish_reason: null, + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: { content: 'streamed response' }, + finish_reason: null, + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: {}, + finish_reason: 'stop', + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [], + usage: { + prompt_tokens: 10, + completion_tokens: 8, + total_tokens: 18, + }, + }, + '[DONE]', + ], + }); + + const resp = await proxyPost( + '/v1/messages', + { + model: fixture?.outputGuardedModelName, + max_tokens: 256, + stream: true, + messages: [ + { role: 'user', content: 'safe prompt for streamed output replay' }, + ], + }, + AUTHORIZED_KEY, + { responseType: 'text' }, + ); + + expect(resp.status).toBe(200); + expect(String(resp.headers['content-type'])).toContain('text/event-stream'); + + const events = parseAnthropicSseEvents(String(resp.data)); + expect(events[0]?.event).toBe('message_start'); + expect(events.some((event) => event.event === 'content_block_delta')).toBe( + true, + ); + expect(events.some((event) => event.event === 'message_delta')).toBe(true); + expect(events.at(-1)?.event).toBe('message_stop'); + + const streamedText = events + .filter((event) => event.event === 'content_block_delta') + .map((event) => JSON.parse(event.data) as { delta?: { text?: string } }) + .map((event) => event.delta?.text ?? '') + .join(''); + expect(streamedText).toBe('safe streamed response'); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + }); }); diff --git a/tests/proxy/guardrail/responses.test.ts b/tests/proxy/guardrail/responses.test.ts index 6a60f77..e7afac3 100644 --- a/tests/proxy/guardrail/responses.test.ts +++ b/tests/proxy/guardrail/responses.test.ts @@ -9,6 +9,28 @@ const AUTHORIZED_KEY = 'sk-proxy-guardrail-responses'; const UPSTREAM_API_KEY = 'upstream-key-guardrail-responses'; const UPSTREAM_MODEL = 'test-model'; +const parseResponsesSseEvents = (sseBody: string) => { + const trimmed = sseBody.trim(); + if (!trimmed) { + return [] as Array<{ event?: string; data: string }>; + } + + return trimmed.split(/\r?\n\r?\n/).map((block) => { + const lines = block + .split(/\r?\n/) + .map((line) => line.trim()) + .filter(Boolean); + + return { + event: lines.find((line) => line.startsWith('event: '))?.slice(7), + data: lines + .filter((line) => line.startsWith('data: ')) + .map((line) => line.slice(6)) + .join('\n'), + }; + }); +}; + describe('proxy guardrail /v1/responses', () => { let fixture: RegexGuardrailFixture | undefined; @@ -103,4 +125,125 @@ describe('proxy guardrail /v1/responses', () => { ).messages[0]?.content, ).toBe('safe prompt for output guardrail'); }); + + test('output regex guardrail blocks matched streamed responses output', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: fixture?.outputGuardedModelName, + input: 'safe prompt for streamed output guardrail', + stream: true, + }, + AUTHORIZED_KEY, + { responseType: 'text' }, + ); + + expect(resp.status).toBe(200); + expect(String(resp.headers['content-type'])).toContain('text/event-stream'); + + const events = parseResponsesSseEvents(String(resp.data)); + expect(events).toHaveLength(1); + expect(events[0]?.event).toBe('error'); + + const payload = JSON.parse(events[0]?.data ?? '{}') as { + type?: string; + message?: string; + }; + expect(payload.type).toBe('error'); + expect(payload.message).toContain('guardrail regex blocked output'); + expect(payload.message).toContain('blocked by regex output guardrail'); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + }); + + test('approved streamed responses output replays original response events', async () => { + fixture?.upstream.configure({ + streamEvents: [ + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: { role: 'assistant', content: 'safe ' }, + finish_reason: null, + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: { content: 'streamed response' }, + finish_reason: null, + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [ + { + index: 0, + delta: {}, + finish_reason: 'stop', + }, + ], + }, + { + id: 'chatcmpl-e2e-mock', + object: 'chat.completion.chunk', + created: 1, + model: UPSTREAM_MODEL, + choices: [], + usage: { + prompt_tokens: 10, + completion_tokens: 8, + total_tokens: 18, + }, + }, + '[DONE]', + ], + }); + + const resp = await proxyPost( + '/v1/responses', + { + model: fixture?.outputGuardedModelName, + input: 'safe prompt for streamed output replay', + stream: true, + }, + AUTHORIZED_KEY, + { responseType: 'text' }, + ); + + expect(resp.status).toBe(200); + expect(String(resp.headers['content-type'])).toContain('text/event-stream'); + + const events = parseResponsesSseEvents(String(resp.data)); + expect(events[0]?.event).toBe('response.created'); + expect( + events.some((event) => event.event === 'response.output_text.delta'), + ).toBe(true); + expect(events.at(-1)?.event).toBe('response.completed'); + + const streamedText = events + .filter((event) => event.event === 'response.output_text.delta') + .map((event) => JSON.parse(event.data) as { delta?: string }) + .map((event) => event.delta ?? '') + .join(''); + expect(streamedText).toBe('safe streamed response'); + + const recorded = fixture?.upstream.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(1); + }); });