diff --git a/src/proxy/handlers/chat_completions/mod.rs b/src/proxy/handlers/chat_completions/mod.rs index c729973..ab42886 100644 --- a/src/proxy/handlers/chat_completions/mod.rs +++ b/src/proxy/handlers/chat_completions/mod.rs @@ -1,297 +1,91 @@ mod span_attributes; mod types; -use std::{convert::Infallible, time::Duration}; - -use axum::{ - Json, - extract::State, - response::{ - IntoResponse, Response, - sse::{Event as SseEvent, Sse}, - }, -}; -use fastrace::prelude::{Event as TraceEvent, *}; -use log::error; +use axum::response::sse::Event as SseEvent; +use fastrace::Span; use opentelemetry_semantic_conventions::attribute::GEN_AI_RESPONSE_FINISH_REASONS; +use reqwest::Url; use span_attributes::{ - StreamOutputCollector, apply_span_properties, chunk_span_properties, request_span_properties, - response_span_properties, usage_span_properties, + StreamOutputCollector, chunk_span_properties, request_span_properties, response_span_properties, }; -use tokio::sync::{oneshot, oneshot::error::TryRecvError}; pub use types::ChatCompletionError; use crate::{ - config::entities::{Model, ResourceEntry}, gateway::{ - error::GatewayError, formats::OpenAIChatFormat, - traits::ChatFormat, + traits::ProviderCapabilities, types::{ common::Usage, - openai::{ChatCompletionRequest, ChatCompletionResponse}, - response::{ChatResponse, ChatResponseStream}, + openai::{ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse}, }, }, - proxy::{ - AppState, - hooks::{self, RequestContext}, - provider::create_provider_instance, - }, - utils::future::{WithSpan, maybe_timeout}, + proxy::handlers::format_handler::FormatHandlerAdapter, }; -pub async fn chat_completions( - State(state): State, - mut request_ctx: RequestContext, - Json(mut request_data): Json, -) -> Result { - hooks::observability::record_start_time(&mut request_ctx).await; - hooks::authorization::check( - &mut request_ctx, - OpenAIChatFormat::extract_model(&request_data).to_owned(), - ) - .await?; - hooks::rate_limit::pre_check(&mut request_ctx).await?; - - let model = request_ctx - .extensions() - .await - .get::>() - .cloned() - .ok_or(ChatCompletionError::MissingModelInContext)?; - - // Replace request model name with real model name - request_data.model = model.model.clone(); - let timeout = model.timeout.map(Duration::from_millis); - - let gateway = state.gateway(); - let resources = state.resources(); - let provider = model.provider(resources.as_ref()).ok_or_else(|| { - GatewayError::Internal(format!("provider {} not found", model.provider_id)) - })?; - let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; - let provider_base_url = provider_instance.effective_base_url().ok(); +pub(crate) struct ChatCompletionsAdapter; - let span = Span::enter_with_local_parent("aisix.llm.chat_completions"); - apply_span_properties( - &span, - request_span_properties( - &request_data, - provider_instance.def.as_ref(), - provider_base_url.as_ref(), - ), - ); +impl FormatHandlerAdapter for ChatCompletionsAdapter { + type Format = OpenAIChatFormat; + type Request = ChatCompletionRequest; + type Response = ChatCompletionResponse; + type StreamChunk = ChatCompletionChunk; + type Error = ChatCompletionError; + type Collector = StreamOutputCollector; - let (response, span) = (WithSpan { - inner: maybe_timeout( - timeout, - gateway.chat_completion(&request_data, &provider_instance), - ), - span: Some(span), - }) - .await; - - match response { - Ok(Ok(ChatResponse::Complete { response, usage })) => { - span.add_properties(|| response_span_properties(&response, &usage)); - handle_regular_request(response, usage, &mut request_ctx).await - } - Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { - handle_stream_request(stream, usage_rx, &mut request_ctx, span).await - } - Ok(Err(err)) => { - span.add_property(|| ("error.type", "gateway_error")); - Err(err.into()) - } - Err(err) => { - span.add_property(|| ("error.type", "timeout")); - Err(ChatCompletionError::Timeout(err)) - } + fn span_name() -> &'static str { + "aisix.llm.chat_completions" } -} -async fn handle_regular_request( - response: ChatCompletionResponse, - usage: Usage, - request_ctx: &mut RequestContext, -) -> Result { - if let Err(err) = hooks::rate_limit::post_check(request_ctx, &usage).await { - error!("Rate limit post_check error: {}", err); + fn missing_model_error() -> Self::Error { + ChatCompletionError::MissingModelInContext } - let mut resp = Json(response).into_response(); - hooks::rate_limit::inject_response_headers(request_ctx, resp.headers_mut()).await; - hooks::observability::record_usage(request_ctx, &usage).await; + fn set_model(request: &mut Self::Request, model: String) { + request.model = model; + } - Ok(resp) -} + fn request_span_properties( + request: &Self::Request, + provider: &dyn ProviderCapabilities, + base_url: Option<&Url>, + ) -> Vec<(String, String)> { + request_span_properties(request, provider, base_url) + } -fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::Receiver) { - tokio::spawn(async move { - let mut request_ctx = request_ctx; + fn response_span_properties(response: &Self::Response, usage: &Usage) -> Vec<(String, String)> { + response_span_properties(response, usage) + } - match usage_rx.await { - Ok(usage) => { - if let Err(err) = - hooks::rate_limit::post_check_streaming(&mut request_ctx, &usage).await - { - error!("Rate limit post_check_streaming error: {}", err); - } - hooks::observability::record_streaming_usage(&mut request_ctx, &usage).await; - } - Err(err) => { - error!("Failed to receive streaming usage from gateway: {}", err); - } + fn apply_chunk_span_properties(span: &Span, chunk: &Self::StreamChunk, is_first_item: bool) { + if is_first_item { + span.add_properties(|| chunk_span_properties(chunk)); + return; } - }); -} - -async fn handle_stream_request( - stream: ChatResponseStream, - usage_rx: oneshot::Receiver, - request_ctx: &mut RequestContext, - span: Span, -) -> Result { - use futures::stream::StreamExt; - - let stream_request_ctx = request_ctx.clone(); - let sse_stream = futures::stream::unfold( - ( - stream, - span, - 0usize, - stream_request_ctx, - false, - false, - Some(usage_rx), - StreamOutputCollector::default(), - ), - |( - mut stream, - span, - idx, - mut request_ctx, - done, - saw_chunk, - mut usage_rx, - mut output_collector, - )| async move { - if done { - drop(span); - return None; - } - - match stream.next().await { - Some(Ok(chunk)) => { - output_collector.record_chunk(&chunk); - if idx == 0 { - hooks::observability::record_first_token_latency(&mut request_ctx).await; - span.add_event( - TraceEvent::new("first token arrived") - .with_property(|| ("kind", "first_token_arrived")), - ); - span.add_properties(|| chunk_span_properties(&chunk)); - } else { - let properties = chunk_span_properties(&chunk); - properties - .iter() - .filter(|(key, _)| { - key == GEN_AI_RESPONSE_FINISH_REASONS - || key == "llm.finish_reason" - || key == "llm.token_count.completion_details.reasoning" - }) - .for_each(|item| span.add_property(|| item.clone())); - } - - let mut event = - SseEvent::default().data(OpenAIChatFormat::serialize_chunk_payload(&chunk)); - if let Some(event_type) = OpenAIChatFormat::sse_event_type(&chunk) { - event = event.event(event_type); - } - let event = Ok::(event); + let properties = chunk_span_properties(chunk); + properties + .iter() + .filter(|(key, _)| { + key == GEN_AI_RESPONSE_FINISH_REASONS + || key == "llm.finish_reason" + || key == "llm.token_count.completion_details.reasoning" + }) + .for_each(|item| span.add_property(|| item.clone())); + } - Some(( - event, - ( - stream, - span, - idx + 1, - request_ctx, - false, - true, - usage_rx, - output_collector, - ), - )) - } - Some(Err(err)) => { - error!("Gateway stream error: {}", err); - span.add_property(|| ("error.type", "stream_error")); - span.add_properties(|| output_collector.output_message_span_properties()); - if let Some(usage_rx) = usage_rx.take() { - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); - } - drop(span); - None - } - None => { - span.add_properties(|| output_collector.output_message_span_properties()); + fn starts_output(_chunk: &Self::StreamChunk) -> bool { + true + } - if let Some(mut usage_rx) = usage_rx.take() { - match usage_rx.try_recv() { - Ok(usage) => { - if let Err(err) = hooks::rate_limit::post_check_streaming( - &mut request_ctx, - &usage, - ) - .await - { - error!("Rate limit post_check_streaming error: {}", err); - } - hooks::observability::record_streaming_usage( - &mut request_ctx, - &usage, - ) - .await; - span.add_properties(|| usage_span_properties(&usage)); - } - Err(TryRecvError::Empty) => { - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); - } - Err(TryRecvError::Closed) => { - error!( - "Failed to receive streaming usage from gateway: channel closed" - ); - } - } - } + fn record_stream_item(collector: &mut Self::Collector, chunk: &Self::StreamChunk) { + collector.record_chunk(chunk); + } - if saw_chunk { - Some(( - Ok(SseEvent::default().data("[DONE]")), - ( - stream, - span, - idx + 1, - request_ctx, - true, - saw_chunk, - usage_rx, - output_collector, - ), - )) - } else { - drop(span); - None - } - } - } - }, - ); + fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)> { + collector.output_message_span_properties() + } - let mut response = Sse::new(sse_stream).into_response(); - hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; - Ok(response) + fn end_of_stream_event(saw_item: bool) -> Option { + saw_item.then(|| SseEvent::default().data("[DONE]")) + } } diff --git a/src/proxy/handlers/chat_completions/span_attributes/mod.rs b/src/proxy/handlers/chat_completions/span_attributes/mod.rs index 4ee6282..43d442d 100644 --- a/src/proxy/handlers/chat_completions/span_attributes/mod.rs +++ b/src/proxy/handlers/chat_completions/span_attributes/mod.rs @@ -7,9 +7,5 @@ pub(super) use telemetry::{ chunk_span_properties, request_span_properties, response_span_properties, }; -pub(super) use crate::proxy::utils::trace::span_attributes::{ - apply_span_properties, usage_span_properties, -}; - #[cfg(test)] mod tests; diff --git a/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs b/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs index a51360d..3ed883c 100644 --- a/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs +++ b/src/proxy/handlers/chat_completions/span_attributes/stream_output.rs @@ -22,15 +22,12 @@ struct StreamOutputChoice { } #[derive(Default)] -pub(in crate::proxy::handlers::chat_completions) struct StreamOutputCollector { +pub(crate) struct StreamOutputCollector { choices: BTreeMap, } impl StreamOutputCollector { - pub(in crate::proxy::handlers::chat_completions) fn record_chunk( - &mut self, - chunk: &ChatCompletionChunk, - ) { + pub(crate) fn record_chunk(&mut self, chunk: &ChatCompletionChunk) { for choice in &chunk.choices { let output_choice = self.choices.entry(choice.index).or_default(); @@ -69,9 +66,7 @@ impl StreamOutputCollector { } } - pub(in crate::proxy::handlers::chat_completions) fn output_message_span_properties( - &self, - ) -> Vec<(String, String)> { + pub(crate) fn output_message_span_properties(&self) -> Vec<(String, String)> { output_message_span_properties(&self.output_message_views()) } diff --git a/src/proxy/handlers/format_handler.rs b/src/proxy/handlers/format_handler.rs new file mode 100644 index 0000000..d977a69 --- /dev/null +++ b/src/proxy/handlers/format_handler.rs @@ -0,0 +1,376 @@ +use std::{convert::Infallible, time::Duration}; + +use axum::{ + Json, + extract::State, + response::{ + IntoResponse, Response, + sse::{Event as SseEvent, Sse}, + }, +}; +use fastrace::prelude::{Event as TraceEvent, Span}; +use log::error; +use reqwest::Url; +use serde::Serialize; +use tokio::{ + sync::{oneshot, oneshot::error::TryRecvError}, + time::error::Elapsed, +}; + +use crate::{ + config::entities::{Model, ResourceEntry}, + gateway::{ + error::GatewayError, + traits::{ChatFormat, ProviderCapabilities}, + types::{ + common::Usage, + response::{ChatResponse, ChatResponseStream}, + }, + }, + proxy::{ + AppState, + hooks::{ + self, RequestContext, authorization::AuthorizationError, rate_limit::RateLimitError, + }, + provider::create_provider_instance, + utils::trace::span_attributes::{apply_span_properties, usage_span_properties}, + }, + utils::future::{WithSpan, maybe_timeout}, +}; + +type AdapterFormat = ::Format; +type AdapterRequest = ::Request; +type AdapterResponse = ::Response; +type AdapterCollector = ::Collector; + +pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { + type Format: ChatFormat< + Request = Self::Request, + Response = Self::Response, + StreamChunk = Self::StreamChunk, + >; + type Request; + type Response: Serialize; + type StreamChunk: Serialize + Send + 'static; + type Error: IntoResponse + + From + + From + + From + + From; + type Collector: Default + Send + 'static; + + fn span_name() -> &'static str; + + fn missing_model_error() -> Self::Error; + + fn set_model(request: &mut Self::Request, model: String); + + fn request_span_properties( + request: &Self::Request, + provider: &dyn ProviderCapabilities, + base_url: Option<&Url>, + ) -> Vec<(String, String)>; + + fn response_span_properties(response: &Self::Response, usage: &Usage) -> Vec<(String, String)>; + + fn apply_chunk_span_properties(span: &Span, chunk: &Self::StreamChunk, is_first_item: bool); + + fn starts_output(chunk: &Self::StreamChunk) -> bool; + + fn record_stream_item(collector: &mut Self::Collector, chunk: &Self::StreamChunk); + + fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)>; + + fn serialize_stream_item(chunk: &Self::StreamChunk) -> SseEvent { + let mut event = + SseEvent::default().data(::serialize_chunk_payload(chunk)); + + if let Some(event_type) = ::sse_event_type(chunk) { + event = event.event(event_type); + } + + event + } + + fn stream_error_event(_error: &GatewayError) -> Option { + None + } + + fn end_of_stream_event(_saw_item: bool) -> Option { + None + } +} + +pub(crate) async fn format_handler( + State(state): State, + mut request_ctx: RequestContext, + Json(mut request_data): Json>, +) -> Result +where + A: FormatHandlerAdapter, +{ + hooks::observability::record_start_time(&mut request_ctx).await; + hooks::authorization::check( + &mut request_ctx, + as ChatFormat>::extract_model(&request_data).to_owned(), + ) + .await?; + hooks::rate_limit::pre_check(&mut request_ctx).await?; + + let model = request_ctx + .extensions() + .await + .get::>() + .cloned() + .ok_or_else(A::missing_model_error)?; + + A::set_model(&mut request_data, model.model.clone()); + let timeout = model.timeout.map(Duration::from_millis); + + let gateway = state.gateway(); + let resources = state.resources(); + let provider = model.provider(resources.as_ref()).ok_or_else(|| { + GatewayError::Internal(format!("provider {} not found", model.provider_id)) + })?; + let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; + let provider_base_url = provider_instance.effective_base_url().ok(); + + let span = Span::enter_with_local_parent(A::span_name()); + apply_span_properties( + &span, + A::request_span_properties( + &request_data, + provider_instance.def.as_ref(), + provider_base_url.as_ref(), + ), + ); + + let (response, span) = (WithSpan { + inner: maybe_timeout( + timeout, + gateway.chat::>(&request_data, &provider_instance), + ), + span: Some(span), + }) + .await; + + match response { + Ok(Ok(ChatResponse::Complete { response, usage })) => { + span.add_properties(|| A::response_span_properties(&response, &usage)); + handle_regular_response::(response, usage, &mut request_ctx).await + } + Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { + handle_stream_response::(stream, usage_rx, &mut request_ctx, span).await + } + Ok(Err(err)) => { + span.add_property(|| ("error.type", "gateway_error")); + Err(err.into()) + } + Err(err) => { + span.add_property(|| ("error.type", "timeout")); + Err(err.into()) + } + } +} + +async fn handle_regular_response( + response: AdapterResponse, + usage: Usage, + request_ctx: &mut RequestContext, +) -> Result +where + A: FormatHandlerAdapter, +{ + if let Err(err) = hooks::rate_limit::post_check(request_ctx, &usage).await { + error!("Rate limit post_check error: {}", err); + } + + let mut response = Json(response).into_response(); + hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; + hooks::observability::record_usage(request_ctx, &usage).await; + + Ok(response) +} + +fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::Receiver) { + tokio::spawn(async move { + let mut request_ctx = request_ctx; + + match usage_rx.await { + Ok(usage) => { + if let Err(err) = + hooks::rate_limit::post_check_streaming(&mut request_ctx, &usage).await + { + error!("Rate limit post_check_streaming error: {}", err); + } + hooks::observability::record_streaming_usage(&mut request_ctx, &usage).await; + } + Err(err) => { + error!("Failed to receive streaming usage from gateway: {}", err); + } + } + }); +} + +async fn finalize_stream_usage( + request_ctx: &mut RequestContext, + usage_rx: &mut Option>, + span: &Span, + output_message_properties: Vec<(String, String)>, +) { + span.add_properties(|| output_message_properties); + + if let Some(mut usage_rx) = usage_rx.take() { + match usage_rx.try_recv() { + Ok(usage) => { + if let Err(err) = hooks::rate_limit::post_check_streaming(request_ctx, &usage).await + { + error!("Rate limit post_check_streaming error: {}", err); + } + hooks::observability::record_streaming_usage(request_ctx, &usage).await; + span.add_properties(|| usage_span_properties(&usage)); + } + Err(TryRecvError::Empty) => { + spawn_stream_usage_observer(request_ctx.clone(), usage_rx); + } + Err(TryRecvError::Closed) => { + error!("Failed to receive streaming usage from gateway: channel closed"); + } + } + } +} + +async fn handle_stream_response( + stream: ChatResponseStream>, + usage_rx: oneshot::Receiver, + request_ctx: &mut RequestContext, + span: Span, +) -> Result +where + A: FormatHandlerAdapter, +{ + use futures::stream::StreamExt; + + let stream_request_ctx = request_ctx.clone(); + let sse_stream = futures::stream::unfold( + ( + stream, + span, + stream_request_ctx, + false, + false, + Some(usage_rx), + AdapterCollector::::default(), + false, + ), + |( + mut stream, + span, + mut request_ctx, + should_terminate, + saw_item, + mut usage_rx, + mut output_collector, + mut first_output_arrived, + )| async move { + if should_terminate { + drop(span); + return None; + } + + match stream.next().await { + Some(Ok(chunk)) => { + A::record_stream_item(&mut output_collector, &chunk); + + let now_starts_output = !first_output_arrived && A::starts_output(&chunk); + if now_starts_output { + first_output_arrived = true; + hooks::observability::record_first_token_latency(&mut request_ctx).await; + span.add_event( + TraceEvent::new("first token arrived") + .with_property(|| ("kind", "first_token_arrived")), + ); + } + + A::apply_chunk_span_properties(&span, &chunk, !saw_item); + + Some(( + Ok::(A::serialize_stream_item(&chunk)), + ( + stream, + span, + request_ctx, + false, + true, + usage_rx, + output_collector, + first_output_arrived, + ), + )) + } + Some(Err(err)) => { + error!("Gateway stream error: {}", err); + span.add_property(|| ("error.type", "stream_error")); + finalize_stream_usage( + &mut request_ctx, + &mut usage_rx, + &span, + A::output_message_span_properties(&output_collector), + ) + .await; + + if let Some(event) = A::stream_error_event(&err) { + Some(( + Ok(event), + ( + stream, + span, + request_ctx, + true, + saw_item, + usage_rx, + output_collector, + first_output_arrived, + ), + )) + } else { + drop(span); + None + } + } + None => { + finalize_stream_usage( + &mut request_ctx, + &mut usage_rx, + &span, + A::output_message_span_properties(&output_collector), + ) + .await; + + if let Some(event) = A::end_of_stream_event(saw_item) { + Some(( + Ok(event), + ( + stream, + span, + request_ctx, + true, + saw_item, + usage_rx, + output_collector, + first_output_arrived, + ), + )) + } else { + drop(span); + None + } + } + } + }, + ); + + let mut response = Sse::new(sse_stream).into_response(); + hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; + Ok(response) +} diff --git a/src/proxy/handlers/messages/mod.rs b/src/proxy/handlers/messages/mod.rs index 565a37d..23cbe05 100644 --- a/src/proxy/handlers/messages/mod.rs +++ b/src/proxy/handlers/messages/mod.rs @@ -1,288 +1,29 @@ mod span_attributes; mod types; -use std::{convert::Infallible, time::Duration}; - -use axum::{ - Json, - extract::State, - response::{ - IntoResponse, Response, - sse::{Event as SseEvent, Sse}, - }, -}; -use fastrace::prelude::{Event as TraceEvent, *}; -use log::error; +use axum::response::sse::Event as SseEvent; +use fastrace::Span; +use reqwest::Url; use span_attributes::{ - StreamOutputCollector, apply_span_properties, chunk_span_properties, request_span_properties, - response_span_properties, usage_span_properties, + StreamOutputCollector, chunk_span_properties, request_span_properties, response_span_properties, }; -use tokio::sync::{oneshot, oneshot::error::TryRecvError}; pub use types::MessagesError; use crate::{ - config::entities::{Model, ResourceEntry}, gateway::{ error::GatewayError, formats::AnthropicMessagesFormat, - traits::ChatFormat, + traits::{ChatFormat, ProviderCapabilities}, types::{ anthropic::{ AnthropicMessagesRequest, AnthropicMessagesResponse, AnthropicStreamEvent, }, common::Usage, - response::{ChatResponse, ChatResponseStream}, }, }, - proxy::{ - AppState, - hooks::{self, RequestContext}, - provider::create_provider_instance, - }, - utils::future::{WithSpan, maybe_timeout}, + proxy::handlers::format_handler::FormatHandlerAdapter, }; -/// Handles Anthropic Messages API requests on `/v1/messages`. -/// -/// Accepts an `AnthropicMessagesRequest` JSON body, applies proxy auth and hook -/// context from `AppState`, `SpanContext`, and `RequestContext`, and returns -/// either a complete Anthropic Messages JSON response or an SSE stream of -/// Anthropic stream events. Failures are mapped into `MessagesError`. -pub async fn messages( - State(state): State, - mut request_ctx: RequestContext, - Json(mut request_data): Json, -) -> Result { - hooks::observability::record_start_time(&mut request_ctx).await; - hooks::authorization::check( - &mut request_ctx, - AnthropicMessagesFormat::extract_model(&request_data).to_owned(), - ) - .await?; - hooks::rate_limit::pre_check(&mut request_ctx).await?; - - let model = request_ctx - .extensions() - .await - .get::>() - .cloned() - .ok_or(MessagesError::MissingModelInContext)?; - - request_data.model = model.model.clone(); - let timeout = model.timeout.map(Duration::from_millis); - - let gateway = state.gateway(); - let resources = state.resources(); - let provider = model.provider(resources.as_ref()).ok_or_else(|| { - GatewayError::Internal(format!("provider {} not found", model.provider_id)) - })?; - let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; - let provider_base_url = provider_instance.effective_base_url().ok(); - - let span = Span::enter_with_local_parent("aisix.llm.messages"); - apply_span_properties( - &span, - request_span_properties( - &request_data, - provider_instance.def.as_ref(), - provider_base_url.as_ref(), - ), - ); - - let (response, span) = (WithSpan { - inner: maybe_timeout(timeout, gateway.messages(&request_data, &provider_instance)), - span: Some(span), - }) - .await; - - match response { - Ok(Ok(ChatResponse::Complete { response, usage })) => { - span.add_properties(|| response_span_properties(&response, &usage)); - handle_regular_request(response, usage, &mut request_ctx).await - } - Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { - handle_stream_request(stream, usage_rx, &mut request_ctx, span).await - } - Ok(Err(err)) => { - span.add_property(|| ("error.type", "gateway_error")); - Err(err.into()) - } - Err(err) => { - span.add_property(|| ("error.type", "timeout")); - Err(MessagesError::Timeout(err)) - } - } -} - -async fn handle_regular_request( - response: AnthropicMessagesResponse, - usage: Usage, - request_ctx: &mut RequestContext, -) -> Result { - if let Err(err) = hooks::rate_limit::post_check(request_ctx, &usage).await { - error!("Rate limit post_check error: {}", err); - } - - let mut resp = Json(response).into_response(); - hooks::rate_limit::inject_response_headers(request_ctx, resp.headers_mut()).await; - hooks::observability::record_usage(request_ctx, &usage).await; - - Ok(resp) -} - -fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::Receiver) { - tokio::spawn(async move { - let mut request_ctx = request_ctx; - - match usage_rx.await { - Ok(usage) => { - if let Err(err) = - hooks::rate_limit::post_check_streaming(&mut request_ctx, &usage).await - { - error!("Rate limit post_check_streaming error: {}", err); - } - hooks::observability::record_streaming_usage(&mut request_ctx, &usage).await; - } - Err(err) => { - error!("Failed to receive streaming usage from gateway: {}", err); - } - } - }); -} - -async fn handle_stream_request( - stream: ChatResponseStream, - usage_rx: oneshot::Receiver, - request_ctx: &mut RequestContext, - span: Span, -) -> Result { - use futures::stream::StreamExt; - - let stream_request_ctx = request_ctx.clone(); - let sse_stream = futures::stream::unfold( - ( - stream, - span, - 0usize, - stream_request_ctx, - false, - Some(usage_rx), - StreamOutputCollector::default(), - false, - ), - |( - mut stream, - span, - idx, - mut request_ctx, - should_terminate, - mut usage_rx, - mut output_collector, - mut first_token_arrived, - )| async move { - if should_terminate { - drop(span); - return None; - } - - match stream.next().await { - Some(Ok(event)) => { - output_collector.record_event(&event); - - if let AnthropicStreamEvent::ContentBlockStart { .. } = event - && !first_token_arrived - { - first_token_arrived = true; - hooks::observability::record_first_token_latency(&mut request_ctx).await; - span.add_event( - TraceEvent::new("first token arrived") - .with_property(|| ("kind", "first_token_arrived")), - ); - } - - span.add_properties(|| chunk_span_properties(&event)); - - let sse_event = Ok::(serialize_stream_event(&event)); - - Some(( - sse_event, - ( - stream, - span, - idx + 1, - request_ctx, - false, - usage_rx, - output_collector, - first_token_arrived, - ), - )) - } - Some(Err(err)) => { - error!("Gateway stream error: {}", err); - span.add_property(|| ("error.type", "stream_error")); - span.add_properties(|| output_collector.output_message_span_properties()); - if let Some(usage_rx) = usage_rx.take() { - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); - } - Some(( - Ok(anthropic_error_sse_event(err.to_string())), - ( - stream, - span, - idx + 1, - request_ctx, - true, - usage_rx, - output_collector, - first_token_arrived, - ), - )) - } - None => { - span.add_properties(|| output_collector.output_message_span_properties()); - - if let Some(mut usage_rx) = usage_rx.take() { - match usage_rx.try_recv() { - Ok(usage) => { - if let Err(err) = hooks::rate_limit::post_check_streaming( - &mut request_ctx, - &usage, - ) - .await - { - error!("Rate limit post_check_streaming error: {}", err); - } - hooks::observability::record_streaming_usage( - &mut request_ctx, - &usage, - ) - .await; - span.add_properties(|| usage_span_properties(&usage)); - } - Err(TryRecvError::Empty) => { - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); - } - Err(TryRecvError::Closed) => { - error!( - "Failed to receive streaming usage from gateway: channel closed" - ); - } - } - } - - drop(span); - None - } - } - }, - ); - - let mut response = Sse::new(sse_stream).into_response(); - hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; - Ok(response) -} - fn anthropic_error_sse_event(message: String) -> SseEvent { SseEvent::default() .event("error") @@ -310,3 +51,62 @@ fn serialize_stream_event(event: &AnthropicStreamEvent) -> SseEvent { sse_event } + +pub(crate) struct MessagesAdapter; + +impl FormatHandlerAdapter for MessagesAdapter { + type Format = AnthropicMessagesFormat; + type Request = AnthropicMessagesRequest; + type Response = AnthropicMessagesResponse; + type StreamChunk = AnthropicStreamEvent; + type Error = MessagesError; + type Collector = StreamOutputCollector; + + fn span_name() -> &'static str { + "aisix.llm.messages" + } + + fn missing_model_error() -> Self::Error { + MessagesError::MissingModelInContext + } + + fn set_model(request: &mut Self::Request, model: String) { + request.model = model; + } + + fn request_span_properties( + request: &Self::Request, + provider: &dyn ProviderCapabilities, + base_url: Option<&Url>, + ) -> Vec<(String, String)> { + request_span_properties(request, provider, base_url) + } + + fn response_span_properties(response: &Self::Response, usage: &Usage) -> Vec<(String, String)> { + response_span_properties(response, usage) + } + + fn apply_chunk_span_properties(span: &Span, chunk: &Self::StreamChunk, _is_first_item: bool) { + span.add_properties(|| chunk_span_properties(chunk)); + } + + fn starts_output(chunk: &Self::StreamChunk) -> bool { + matches!(chunk, AnthropicStreamEvent::ContentBlockStart { .. }) + } + + fn record_stream_item(collector: &mut Self::Collector, chunk: &Self::StreamChunk) { + collector.record_event(chunk); + } + + fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)> { + collector.output_message_span_properties() + } + + fn serialize_stream_item(chunk: &Self::StreamChunk) -> SseEvent { + serialize_stream_event(chunk) + } + + fn stream_error_event(error: &GatewayError) -> Option { + Some(anthropic_error_sse_event(error.to_string())) + } +} diff --git a/src/proxy/handlers/messages/span_attributes/mod.rs b/src/proxy/handlers/messages/span_attributes/mod.rs index 4ee6282..43d442d 100644 --- a/src/proxy/handlers/messages/span_attributes/mod.rs +++ b/src/proxy/handlers/messages/span_attributes/mod.rs @@ -7,9 +7,5 @@ pub(super) use telemetry::{ chunk_span_properties, request_span_properties, response_span_properties, }; -pub(super) use crate::proxy::utils::trace::span_attributes::{ - apply_span_properties, usage_span_properties, -}; - #[cfg(test)] mod tests; diff --git a/src/proxy/handlers/messages/span_attributes/stream_output.rs b/src/proxy/handlers/messages/span_attributes/stream_output.rs index f0c701a..ba3c8e5 100644 --- a/src/proxy/handlers/messages/span_attributes/stream_output.rs +++ b/src/proxy/handlers/messages/span_attributes/stream_output.rs @@ -22,17 +22,14 @@ enum StreamOutputBlock { } #[derive(Default)] -pub(in crate::proxy::handlers::messages) struct StreamOutputCollector { +pub(crate) struct StreamOutputCollector { role: Option, blocks: BTreeMap, finish_reason: Option, } impl StreamOutputCollector { - pub(in crate::proxy::handlers::messages) fn record_event( - &mut self, - event: &AnthropicStreamEvent, - ) { + pub(crate) fn record_event(&mut self, event: &AnthropicStreamEvent) { match event { AnthropicStreamEvent::MessageStart { message } => { self.role = Some(message.role.clone()); @@ -104,9 +101,7 @@ impl StreamOutputCollector { } } - pub(in crate::proxy::handlers::messages) fn output_message_span_properties( - &self, - ) -> Vec<(String, String)> { + pub(crate) fn output_message_span_properties(&self) -> Vec<(String, String)> { output_message_span_properties(&self.output_message_views()) } diff --git a/src/proxy/handlers/mod.rs b/src/proxy/handlers/mod.rs index 928b21a..be5eb96 100644 --- a/src/proxy/handlers/mod.rs +++ b/src/proxy/handlers/mod.rs @@ -1,5 +1,6 @@ pub mod chat_completions; pub mod embeddings; +pub mod format_handler; pub mod messages; pub mod models; pub mod responses; diff --git a/src/proxy/handlers/responses/mod.rs b/src/proxy/handlers/responses/mod.rs index 7609dbc..b8d39a1 100644 --- a/src/proxy/handlers/responses/mod.rs +++ b/src/proxy/handlers/responses/mod.rs @@ -1,315 +1,98 @@ mod span_attributes; mod types; -use std::{convert::Infallible, time::Duration}; - -use axum::{ - Json, - extract::State, - response::{ - IntoResponse, Response, - sse::{Event as SseEvent, Sse}, - }, -}; -use fastrace::prelude::{Event as TraceEvent, *}; -use log::error; +use axum::response::sse::Event as SseEvent; +use fastrace::Span; +use reqwest::Url; use span_attributes::{ - StreamOutputCollector, apply_span_properties, chunk_span_properties, event_starts_output, - request_span_properties, response_span_properties, usage_span_properties, + StreamOutputCollector, chunk_span_properties, event_starts_output, request_span_properties, + response_span_properties, }; -use tokio::sync::{oneshot, oneshot::error::TryRecvError}; pub use types::ResponsesError; use crate::{ - config::entities::{Model, ResourceEntry}, gateway::{ error::GatewayError, formats::ResponsesApiFormat, - traits::ChatFormat, + traits::{ChatFormat, ProviderCapabilities}, types::{ common::Usage, openai::responses::{ ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent, }, - response::{ChatResponse, ChatResponseStream}, }, }, - proxy::{ - AppState, - hooks::{self, RequestContext}, - provider::create_provider_instance, - }, - utils::future::{WithSpan, maybe_timeout}, + proxy::handlers::format_handler::FormatHandlerAdapter, }; -pub async fn responses( - State(state): State, - mut request_ctx: RequestContext, - Json(mut request_data): Json, -) -> Result { - hooks::observability::record_start_time(&mut request_ctx).await; - hooks::authorization::check( - &mut request_ctx, - ResponsesApiFormat::extract_model(&request_data).to_owned(), - ) - .await?; - hooks::rate_limit::pre_check(&mut request_ctx).await?; - - let model = request_ctx - .extensions() - .await - .get::>() - .cloned() - .ok_or(ResponsesError::MissingModelInContext)?; - - request_data.model = model.model.clone(); - let timeout = model.timeout.map(Duration::from_millis); - - let gateway = state.gateway(); - let resources = state.resources(); - let provider = model.provider(resources.as_ref()).ok_or_else(|| { - GatewayError::Internal(format!("provider {} not found", model.provider_id)) - })?; - let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; - let provider_base_url = provider_instance.effective_base_url().ok(); - - let span = Span::enter_with_local_parent("aisix.llm.responses"); - apply_span_properties( - &span, - request_span_properties( - &request_data, - provider_instance.def.as_ref(), - provider_base_url.as_ref(), - ), - ); - - let (response, span) = (WithSpan { - inner: maybe_timeout( - timeout, - gateway.chat::(&request_data, &provider_instance), - ), - span: Some(span), - }) - .await; - - match response { - Ok(Ok(ChatResponse::Complete { response, usage })) => { - span.add_properties(|| response_span_properties(&response, &usage)); - handle_regular_request(response, usage, &mut request_ctx).await - } - Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { - handle_stream_request(stream, usage_rx, &mut request_ctx, span).await - } - Ok(Err(err)) => { - span.add_property(|| ("error.type", "gateway_error")); - Err(err.into()) - } - Err(err) => { - span.add_property(|| ("error.type", "timeout")); - Err(ResponsesError::Timeout(err)) - } - } -} +fn serialize_stream_event(event: &ResponsesApiStreamEvent) -> SseEvent { + let mut sse_event = + SseEvent::default().data(ResponsesApiFormat::serialize_chunk_payload(event)); -async fn handle_regular_request( - response: ResponsesApiResponse, - usage: Usage, - request_ctx: &mut RequestContext, -) -> Result { - if let Err(err) = hooks::rate_limit::post_check(request_ctx, &usage).await { - error!("Rate limit post_check error: {}", err); + if let Some(event_type) = ResponsesApiFormat::sse_event_type(event) { + sse_event = sse_event.event(event_type); } - let mut response = Json(response).into_response(); - hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; - hooks::observability::record_usage(request_ctx, &usage).await; - - Ok(response) -} - -fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::Receiver) { - tokio::spawn(async move { - let mut request_ctx = request_ctx; - - match usage_rx.await { - Ok(usage) => { - if let Err(err) = - hooks::rate_limit::post_check_streaming(&mut request_ctx, &usage).await - { - error!("Rate limit post_check_streaming error: {}", err); - } - hooks::observability::record_streaming_usage(&mut request_ctx, &usage).await; - } - Err(err) => { - error!("Failed to receive streaming usage from gateway: {}", err); - } - } - }); + sse_event } -async fn handle_stream_request( - stream: ChatResponseStream, - usage_rx: oneshot::Receiver, - request_ctx: &mut RequestContext, - span: Span, -) -> Result { - use futures::stream::StreamExt; - - let stream_request_ctx = request_ctx.clone(); - let sse_stream = futures::stream::unfold( - ( - stream, - span, - stream_request_ctx, - false, - Some(usage_rx), - StreamOutputCollector::default(), - false, - ), - |( - mut stream, - span, - mut request_ctx, - should_terminate, - mut usage_rx, - mut output_collector, - mut first_token_arrived, - )| async move { - if should_terminate { - drop(span); - return None; - } - - match stream.next().await { - Some(Ok(event)) => { - output_collector.record_event(&event); +pub(crate) struct ResponsesAdapter; - if event_starts_output(&event) && !first_token_arrived { - first_token_arrived = true; - hooks::observability::record_first_token_latency(&mut request_ctx).await; - span.add_event( - TraceEvent::new("first token arrived") - .with_property(|| ("kind", "first_token_arrived")), - ); - } +impl FormatHandlerAdapter for ResponsesAdapter { + type Format = ResponsesApiFormat; + type Request = ResponsesApiRequest; + type Response = ResponsesApiResponse; + type StreamChunk = ResponsesApiStreamEvent; + type Error = ResponsesError; + type Collector = StreamOutputCollector; - span.add_properties(|| chunk_span_properties(&event)); + fn span_name() -> &'static str { + "aisix.llm.responses" + } - let sse_event = Ok::(serialize_stream_event(&event)); + fn missing_model_error() -> Self::Error { + ResponsesError::MissingModelInContext + } - Some(( - sse_event, - ( - stream, - span, - request_ctx, - false, - usage_rx, - output_collector, - first_token_arrived, - ), - )) - } - Some(Err(err)) => { - error!("Gateway stream error: {}", err); - span.add_property(|| ("error.type", "stream_error")); - span.add_properties(|| output_collector.output_message_span_properties()); + fn set_model(request: &mut Self::Request, model: String) { + request.model = model; + } - if let Some(mut usage_rx) = usage_rx.take() { - match usage_rx.try_recv() { - Ok(usage) => { - if let Err(err) = hooks::rate_limit::post_check_streaming( - &mut request_ctx, - &usage, - ) - .await - { - error!("Rate limit post_check_streaming error: {}", err); - } - hooks::observability::record_streaming_usage( - &mut request_ctx, - &usage, - ) - .await; - span.add_properties(|| usage_span_properties(&usage)); - } - Err(TryRecvError::Empty) => { - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); - } - Err(TryRecvError::Closed) => { - error!( - "Failed to receive streaming usage from gateway: channel closed" - ); - } - } - } + fn request_span_properties( + request: &Self::Request, + provider: &dyn ProviderCapabilities, + base_url: Option<&Url>, + ) -> Vec<(String, String)> { + request_span_properties(request, provider, base_url) + } - Some(( - Ok(serialize_stream_event(&ResponsesApiStreamEvent::Error { - message: err.to_string(), - })), - ( - stream, - span, - request_ctx, - true, - usage_rx, - output_collector, - first_token_arrived, - ), - )) - } - None => { - span.add_properties(|| output_collector.output_message_span_properties()); + fn response_span_properties(response: &Self::Response, usage: &Usage) -> Vec<(String, String)> { + response_span_properties(response, usage) + } - if let Some(mut usage_rx) = usage_rx.take() { - match usage_rx.try_recv() { - Ok(usage) => { - if let Err(err) = hooks::rate_limit::post_check_streaming( - &mut request_ctx, - &usage, - ) - .await - { - error!("Rate limit post_check_streaming error: {}", err); - } - hooks::observability::record_streaming_usage( - &mut request_ctx, - &usage, - ) - .await; - span.add_properties(|| usage_span_properties(&usage)); - } - Err(TryRecvError::Empty) => { - spawn_stream_usage_observer(request_ctx.clone(), usage_rx); - } - Err(TryRecvError::Closed) => { - error!( - "Failed to receive streaming usage from gateway: channel closed" - ); - } - } - } + fn apply_chunk_span_properties(span: &Span, chunk: &Self::StreamChunk, _is_first_item: bool) { + span.add_properties(|| chunk_span_properties(chunk)); + } - drop(span); - None - } - } - }, - ); + fn starts_output(chunk: &Self::StreamChunk) -> bool { + event_starts_output(chunk) + } - let mut response = Sse::new(sse_stream).into_response(); - hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; - Ok(response) -} + fn record_stream_item(collector: &mut Self::Collector, chunk: &Self::StreamChunk) { + collector.record_event(chunk); + } -fn serialize_stream_event(event: &ResponsesApiStreamEvent) -> SseEvent { - let mut sse_event = - SseEvent::default().data(ResponsesApiFormat::serialize_chunk_payload(event)); + fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)> { + collector.output_message_span_properties() + } - if let Some(event_type) = ResponsesApiFormat::sse_event_type(event) { - sse_event = sse_event.event(event_type); + fn serialize_stream_item(chunk: &Self::StreamChunk) -> SseEvent { + serialize_stream_event(chunk) } - sse_event + fn stream_error_event(error: &GatewayError) -> Option { + Some(serialize_stream_event(&ResponsesApiStreamEvent::Error { + message: error.to_string(), + })) + } } diff --git a/src/proxy/handlers/responses/span_attributes/mod.rs b/src/proxy/handlers/responses/span_attributes/mod.rs index 4e24176..e401321 100644 --- a/src/proxy/handlers/responses/span_attributes/mod.rs +++ b/src/proxy/handlers/responses/span_attributes/mod.rs @@ -7,9 +7,5 @@ pub(super) use telemetry::{ chunk_span_properties, event_starts_output, request_span_properties, response_span_properties, }; -pub(super) use crate::proxy::utils::trace::span_attributes::{ - apply_span_properties, usage_span_properties, -}; - #[cfg(test)] mod tests; diff --git a/src/proxy/handlers/responses/span_attributes/stream_output.rs b/src/proxy/handlers/responses/span_attributes/stream_output.rs index 8abcdae..3953b97 100644 --- a/src/proxy/handlers/responses/span_attributes/stream_output.rs +++ b/src/proxy/handlers/responses/span_attributes/stream_output.rs @@ -9,16 +9,13 @@ use crate::{ }; #[derive(Default)] -pub(in crate::proxy::handlers::responses) struct StreamOutputCollector { +pub(crate) struct StreamOutputCollector { items: BTreeMap, completed_response: Option, } impl StreamOutputCollector { - pub(in crate::proxy::handlers::responses) fn record_event( - &mut self, - event: &ResponsesApiStreamEvent, - ) { + pub(crate) fn record_event(&mut self, event: &ResponsesApiStreamEvent) { match event { ResponsesApiStreamEvent::ResponseCreated { response } | ResponsesApiStreamEvent::ResponseInProgress { response } => { @@ -108,9 +105,7 @@ impl StreamOutputCollector { } } - pub(in crate::proxy::handlers::responses) fn output_message_span_properties( - &self, - ) -> Vec<(String, String)> { + pub(crate) fn output_message_span_properties(&self) -> Vec<(String, String)> { if let Some(response) = &self.completed_response { return output_message_span_properties(&output_message_views_from_output_items( &response.output, diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 01550ac..c461b92 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -53,13 +53,21 @@ pub fn create_router(state: AppState) -> Router { .merge(Router::new().route("/v1/models", get(handlers::models::list_models))) .route( "/v1/chat/completions", - post(handlers::chat_completions::chat_completions), + post( + handlers::format_handler::format_handler::< + handlers::chat_completions::ChatCompletionsAdapter, + >, + ), ) .route( "/v1/messages", - post(handlers::messages::messages).layer(DefaultBodyLimit::max(32 * 1024 * 1024)), + post(handlers::format_handler::format_handler::) + .layer(DefaultBodyLimit::max(32 * 1024 * 1024)), + ) + .route( + "/v1/responses", + post(handlers::format_handler::format_handler::), ) - .route("/v1/responses", post(handlers::responses::responses)) .route("/v1/embeddings", post(handlers::embeddings::embeddings)) .layer(DefaultBodyLimit::max(10 * 1024 * 1024)) .layer(from_fn_with_state(state.clone(), middlewares::auth))