diff --git a/crates/aisix-llm/src/formats/openai/responses.rs b/crates/aisix-llm/src/formats/openai/responses.rs index 64bc605..08ac912 100644 --- a/crates/aisix-llm/src/formats/openai/responses.rs +++ b/crates/aisix-llm/src/formats/openai/responses.rs @@ -84,6 +84,8 @@ impl ChatFormat for ResponsesApiFormat { }); } + messages.extend(req.replay_messages.iter().cloned()); + match &req.input { ResponsesInput::Text(text) => messages.push(ChatMessage { role: "user".into(), @@ -910,7 +912,7 @@ mod tests { types::{ common::{BridgeContext, OpenAIResponsesExtras}, openai::{ - ChatCompletionChunk, ChatCompletionResponse, + ChatCompletionChunk, ChatCompletionResponse, ChatMessage, MessageContent, responses::{ ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent, ResponsesOutputItem, @@ -1032,6 +1034,52 @@ mod tests { assert_eq!(extras.metadata.as_ref().unwrap()["request_id"], "req_1"); } + #[test] + fn to_hub_prepends_replay_messages_after_system_instructions() { + let mut request: ResponsesApiRequest = serde_json::from_value(json!({ + "model": "gpt-4.1", + "input": "current turn", + "instructions": "Be concise" + })) + .unwrap(); + request.replay_messages = vec![ + ChatMessage { + role: "user".into(), + content: Some(MessageContent::Text("previous user".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ChatMessage { + role: "assistant".into(), + content: Some(MessageContent::Text("previous assistant".into())), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ]; + + let (hub, _) = ResponsesApiFormat::to_hub(&request).unwrap(); + + assert_eq!(hub.messages.len(), 4); + assert_eq!(hub.messages[0].role, "system"); + assert_eq!(hub.messages[1].role, "user"); + assert_eq!(hub.messages[2].role, "assistant"); + assert_eq!(hub.messages[3].role, "user"); + assert_matches!( + hub.messages[1].content.as_ref(), + Some(MessageContent::Text(text)) if text == "previous user" + ); + assert_matches!( + hub.messages[2].content.as_ref(), + Some(MessageContent::Text(text)) if text == "previous assistant" + ); + assert_matches!( + hub.messages[3].content.as_ref(), + Some(MessageContent::Text(text)) if text == "current turn" + ); + } + #[test] fn to_hub_rejects_builtin_tools_and_truncation_auto() { let built_in_tools_request: ResponsesApiRequest = serde_json::from_value(json!({ diff --git a/crates/aisix-llm/src/gateway.rs b/crates/aisix-llm/src/gateway.rs index 6ecb28e..f569db7 100644 --- a/crates/aisix-llm/src/gateway.rs +++ b/crates/aisix-llm/src/gateway.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration}; +use std::{pin::Pin, time::Duration}; use bytes::Bytes; use futures::Stream; @@ -8,19 +8,15 @@ use tokio::sync::oneshot; use crate::{ error::{GatewayError, Result}, - formats::{AnthropicMessagesFormat, OpenAIChatFormat, ResponsesApiFormat}, + formats::{AnthropicMessagesFormat, OpenAIChatFormat}, provider_instance::{ProviderInstance, ProviderRegistry}, - session::{SessionStore, StoredSession}, streams::{BridgedStream, HubChunkStream, NativeStream, aws_event_stream_reader, sse_reader}, traits::{ChatFormat, NativeHandler, PreparedRequest, StreamReaderKind}, types::{ anthropic::AnthropicMessagesRequest, - common::{BridgeContext, OpenAIResponsesExtras, Usage}, + common::{BridgeContext, Usage}, embed::{EmbedRequestBody, EmbedResponseBody, EmbeddingRequest, EmbeddingResponse}, - openai::{ - ChatCompletionRequest, ChatCompletionResponse, ChatMessage, - responses::ConversationReference, - }, + openai::{ChatCompletionRequest, ChatCompletionResponse}, response::ChatResponse, }, }; @@ -34,7 +30,6 @@ enum HttpResponseBody { pub struct Gateway { registry: ProviderRegistry, http_client: reqwest::Client, - session_store: Option>, } impl Gateway { @@ -51,17 +46,9 @@ impl Gateway { Self { registry, http_client, - session_store: None, } } - /// Enables server-side session state for formats that need it. - #[allow(dead_code)] - pub fn with_session_store(mut self, store: Arc) -> Self { - self.session_store = Some(store); - self - } - /// Returns the immutable provider registry backing this gateway. pub fn registry(&self) -> &ProviderRegistry { &self.registry @@ -82,12 +69,19 @@ impl Gateway { .await; } - let (mut hub_request, ctx) = F::to_hub(request)?; - if F::name() == ResponsesApiFormat::name() { - self.resolve_session(&mut hub_request, &ctx).await?; - } + let (hub_request, ctx) = F::to_hub(request)?; + self.chat_from_hub::(hub_request, ctx, instance).await + } - if stream { + /// Executes a bridged hub request without any gateway-owned session state. + #[fastrace::trace] + async fn chat_from_hub( + &self, + hub_request: ChatCompletionRequest, + ctx: BridgeContext, + instance: &ProviderInstance, + ) -> Result> { + if hub_request.stream.unwrap_or(false) { let hub_stream = self.call_chat_hub_stream(instance, &hub_request).await?; let (usage_tx, usage_rx) = oneshot::channel(); let bridged_stream = BridgedStream::::new(hub_stream, ctx, usage_tx); @@ -99,9 +93,6 @@ impl Gateway { } let hub_response = self.call_chat_hub(instance, &hub_request).await?; - if F::name() == ResponsesApiFormat::name() { - self.save_session(&ctx, &hub_request, &hub_response).await?; - } let usage = extract_chat_usage_from_response(&hub_response).unwrap_or_default(); let response = F::from_hub(&hub_response, &ctx)?; @@ -179,61 +170,6 @@ impl Gateway { transform.transform_embeddings_response(response_body) } - async fn resolve_session( - &self, - hub_request: &mut ChatCompletionRequest, - ctx: &BridgeContext, - ) -> Result<()> { - let Some(store) = self.session_store.as_deref() else { - return Ok(()); - }; - let Some(history) = load_responses_session_messages(store, ctx).await? else { - return Ok(()); - }; - - let insert_at = usize::from(responses_instructions_present(ctx)); - hub_request.messages.splice(insert_at..insert_at, history); - Ok(()) - } - - async fn save_session( - &self, - ctx: &BridgeContext, - hub_request: &ChatCompletionRequest, - hub_response: &ChatCompletionResponse, - ) -> Result<()> { - let Some(store) = self.session_store.as_deref() else { - return Ok(()); - }; - if matches!( - ctx.openai_responses_extras - .as_ref() - .and_then(|extras| extras.store), - Some(false) - ) { - return Ok(()); - } - - let mut messages = session_messages_from_hub_request(hub_request, ctx); - messages.extend( - hub_response - .choices - .iter() - .map(|choice| choice.message.clone()), - ); - - let session = StoredSession { - response_id: hub_response.id.clone(), - conversation_id: responses_conversation_id(ctx.openai_responses_extras.as_ref()), - messages, - model: hub_response.model.clone(), - created_at: hub_response.created, - insertion_index: 0, - metadata: responses_metadata(ctx.openai_responses_extras.as_ref()), - }; - store.put_session(&session).await - } - fn prepare_json_request( &self, instance: &ProviderInstance, @@ -381,64 +317,6 @@ impl Gateway { } } -async fn load_responses_session_messages( - store: &dyn SessionStore, - ctx: &BridgeContext, -) -> Result>> { - let extras = ctx.openai_responses_extras.as_ref(); - - if let Some(previous_response_id) = - extras.and_then(|extras| extras.previous_response_id.as_deref()) - { - let session = store - .get_by_response_id(previous_response_id) - .await? - .ok_or_else(|| { - GatewayError::Validation(format!( - "previous_response_not_found: {}", - previous_response_id - )) - })?; - return Ok(Some(session.messages)); - } - - let Some(conversation_id) = responses_conversation_id(extras) else { - return Ok(None); - }; - let sessions = store.get_by_conversation_id(&conversation_id).await?; - Ok(sessions.last().map(|session| session.messages.clone())) -} - -fn session_messages_from_hub_request( - hub_request: &ChatCompletionRequest, - ctx: &BridgeContext, -) -> Vec { - let skip = usize::from(responses_instructions_present(ctx)); - hub_request.messages.iter().skip(skip).cloned().collect() -} - -fn responses_instructions_present(ctx: &BridgeContext) -> bool { - ctx.openai_responses_extras - .as_ref() - .and_then(|extras| extras.instructions.as_ref()) - .is_some_and(|instructions| !instructions.is_empty()) -} - -fn responses_conversation_id(extras: Option<&OpenAIResponsesExtras>) -> Option { - match extras?.conversation.as_ref()? { - ConversationReference::Id(id) => Some(id.clone()), - ConversationReference::Descriptor { id } => Some(id.clone()), - } -} - -fn responses_metadata(extras: Option<&OpenAIResponsesExtras>) -> HashMap { - extras - .and_then(|extras| extras.metadata.as_ref()) - .and_then(Value::as_object) - .map(|metadata| metadata.clone().into_iter().collect()) - .unwrap_or_default() -} - fn ensure_chat_stream_reader_supported(kind: StreamReaderKind) -> Result<()> { match kind { StreamReaderKind::Sse | StreamReaderKind::AwsEventStream => Ok(()), @@ -560,12 +438,10 @@ mod tests { use super::Gateway; use crate::{ error::{GatewayError, Result}, - formats::ResponsesApiFormat, provider_instance::{ AwsStaticCredentials, ProviderAuth, ProviderInstance, ProviderRegistry, }, providers::{AnthropicDef, BedrockDef}, - session::InMemorySessionStore, traits::{ ChatFormat, ChatTransform, EmbedTransform, NativeHandler, NativeOpenAIResponsesSupport, OpenAIResponsesNativeStreamState, PreparedRequest, ProviderCapabilities, ProviderMeta, @@ -573,7 +449,7 @@ mod tests { }, types::{ anthropic::{AnthropicContentBlock, AnthropicMessagesRequest}, - common::{BridgeContext, OpenAIResponsesExtras, Usage}, + common::{BridgeContext, Usage}, embed::EmbeddingRequest, openai::{ ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, @@ -1062,175 +938,6 @@ mod tests { server.abort(); } - #[tokio::test] - async fn responses_requests_save_and_resolve_previous_response_id_sessions() { - let observed: Arc>> = Arc::new(Mutex::new(Vec::new())); - let observed_clone = Arc::clone(&observed); - let router = Router::new().route( - "/v1/chat/completions", - post(move |Json(body): Json| { - let observed = Arc::clone(&observed_clone); - async move { - let mut observed = observed.lock().await; - observed.push(body); - let request_number = observed.len(); - let reply = if request_number == 1 { - "first reply" - } else { - "second reply" - }; - - Json(json!({ - "id": format!("chatcmpl-{}", request_number), - "object": "chat.completion", - "created": request_number, - "model": "gpt-test", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": reply - }, - "finish_reason": "stop" - }] - })) - } - }), - ); - let (base_url, server) = spawn_server(router).await; - - let gateway = Gateway::new(ProviderRegistry::builder().build()) - .with_session_store(Arc::new(InMemorySessionStore::default())); - let instance = ProviderInstance { - def: Arc::new(HubTestProvider), - auth: ProviderAuth::ApiKey("hub-secret".into()), - base_url_override: Some(base_url), - custom_headers: HeaderMap::new(), - }; - - let first_request: ResponsesApiRequest = serde_json::from_value(json!({ - "model": "gpt-test", - "input": "hello" - })) - .unwrap(); - let first_response = gateway - .chat::(&first_request, &instance) - .await - .unwrap(); - let ChatResponse::Complete { - response: first_response, - .. - } = first_response - else { - panic!("expected complete responses response") - }; - assert_eq!(first_response.id, "chatcmpl-1"); - - let second_request: ResponsesApiRequest = serde_json::from_value(json!({ - "model": "gpt-test", - "input": "how are you?", - "previous_response_id": "chatcmpl-1" - })) - .unwrap(); - let second_response = gateway - .chat::(&second_request, &instance) - .await - .unwrap(); - let ChatResponse::Complete { - response: second_response, - .. - } = second_response - else { - panic!("expected complete responses response") - }; - assert_eq!(second_response.id, "chatcmpl-2"); - - let observed = observed.lock().await; - assert_eq!(observed.len(), 2); - assert_eq!(observed[0]["messages"][0]["content"], "hello"); - assert_eq!(observed[1]["messages"][0]["role"], "user"); - assert_eq!(observed[1]["messages"][0]["content"], "hello"); - assert_eq!(observed[1]["messages"][1]["role"], "assistant"); - assert_eq!(observed[1]["messages"][1]["content"], "first reply"); - assert_eq!(observed[1]["messages"][2]["role"], "user"); - assert_eq!(observed[1]["messages"][2]["content"], "how are you?"); - - server.abort(); - } - - #[test] - fn session_messages_from_hub_request_drops_synthetic_responses_instructions() { - let hub_request: ChatCompletionRequest = serde_json::from_value(json!({ - "model": "gpt-test", - "messages": [ - {"role": "system", "content": "be terse"}, - {"role": "user", "content": "hello"} - ] - })) - .unwrap(); - let ctx = BridgeContext { - openai_responses_extras: Some(OpenAIResponsesExtras { - previous_response_id: None, - instructions: Some("be terse".into()), - store: None, - metadata: None, - background: None, - context_management: None, - conversation: None, - include: None, - max_tool_calls: None, - prompt: None, - prompt_cache_key: None, - prompt_cache_retention: None, - reasoning: None, - safety_identifier: None, - service_tier: None, - stream_options: None, - text: None, - top_logprobs: None, - truncation: None, - }), - ..Default::default() - }; - - let session_messages = super::session_messages_from_hub_request(&hub_request, &ctx); - - assert_eq!(session_messages.len(), 1); - assert_eq!(session_messages[0].role, "user"); - assert_matches!( - session_messages[0].content.as_ref(), - Some(crate::types::openai::MessageContent::Text(text)) if text == "hello" - ); - } - - #[tokio::test] - async fn responses_previous_response_id_not_found_returns_validation() { - let gateway = Gateway::new(ProviderRegistry::builder().build()) - .with_session_store(Arc::new(InMemorySessionStore::default())); - let instance = ProviderInstance { - def: Arc::new(HubTestProvider), - auth: ProviderAuth::ApiKey("hub-secret".into()), - base_url_override: Some(Url::parse("http://127.0.0.1:1").unwrap()), - custom_headers: HeaderMap::new(), - }; - let request: ResponsesApiRequest = serde_json::from_value(json!({ - "model": "gpt-test", - "input": "hello", - "previous_response_id": "resp_missing" - })) - .unwrap(); - - let result = gateway - .chat::(&request, &instance) - .await; - - assert_matches!( - result.map(|_| ()), - Err(GatewayError::Validation(message)) - if message.contains("previous_response_not_found") - ); - } - #[tokio::test] async fn chat_completion_applies_provider_prepare_request() { let observed: Arc> = Arc::new(Mutex::new(None)); diff --git a/crates/aisix-llm/src/lib.rs b/crates/aisix-llm/src/lib.rs index d5f87eb..5d5aaa2 100644 --- a/crates/aisix-llm/src/lib.rs +++ b/crates/aisix-llm/src/lib.rs @@ -3,7 +3,6 @@ pub mod formats; pub mod gateway; pub mod provider_instance; pub mod providers; -pub mod session; pub mod streams; pub mod traits; pub mod types; diff --git a/crates/aisix-llm/src/session.rs b/crates/aisix-llm/src/session.rs deleted file mode 100644 index eb8db3c..0000000 --- a/crates/aisix-llm/src/session.rs +++ /dev/null @@ -1,188 +0,0 @@ -use std::{ - collections::HashMap, - sync::atomic::{AtomicU64, Ordering}, -}; - -use async_trait::async_trait; -use serde_json::Value; -use tokio::sync::RwLock; - -use crate::{error::Result, types::openai::ChatMessage}; - -#[allow(dead_code)] -#[async_trait] -pub trait SessionStore: Send + Sync + 'static { - async fn get_by_response_id(&self, response_id: &str) -> Result>; - async fn get_by_conversation_id(&self, conv_id: &str) -> Result>; - async fn put_session(&self, session: &StoredSession) -> Result<()>; - async fn delete_session(&self, response_id: &str) -> Result<()>; -} - -#[allow(dead_code)] -#[derive(Debug, Clone, Default)] -pub struct StoredSession { - pub response_id: String, - pub conversation_id: Option, - pub messages: Vec, - pub model: String, - pub created_at: u64, - pub insertion_index: u64, - pub metadata: HashMap, -} - -#[allow(dead_code)] -#[derive(Debug, Default)] -pub struct InMemorySessionStore { - sessions: RwLock>, - next_insertion_index: AtomicU64, -} - -#[async_trait] -impl SessionStore for InMemorySessionStore { - async fn get_by_response_id(&self, response_id: &str) -> Result> { - Ok(self.sessions.read().await.get(response_id).cloned()) - } - - async fn get_by_conversation_id(&self, conv_id: &str) -> Result> { - let mut sessions: Vec<_> = self - .sessions - .read() - .await - .values() - .filter(|session| session.conversation_id.as_deref() == Some(conv_id)) - .cloned() - .collect(); - sessions.sort_by_key(|session| (session.created_at, session.insertion_index)); - Ok(sessions) - } - - async fn put_session(&self, session: &StoredSession) -> Result<()> { - let mut stored = session.clone(); - let mut sessions = self.sessions.write().await; - stored.insertion_index = sessions - .get(&stored.response_id) - .map(|existing| existing.insertion_index) - .unwrap_or_else(|| self.next_insertion_index.fetch_add(1, Ordering::Relaxed)); - sessions.insert(stored.response_id.clone(), stored); - Ok(()) - } - - async fn delete_session(&self, response_id: &str) -> Result<()> { - self.sessions.write().await.remove(response_id); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use pretty_assertions::assert_eq; - use serde_json::json; - - use super::{InMemorySessionStore, SessionStore, StoredSession}; - use crate::types::openai::{ChatMessage, MessageContent}; - - fn sample_message(text: &str) -> ChatMessage { - ChatMessage { - role: "user".into(), - content: Some(MessageContent::Text(text.into())), - name: None, - tool_calls: None, - tool_call_id: None, - } - } - - #[tokio::test] - async fn in_memory_session_store_crud_round_trip() { - let store = InMemorySessionStore::default(); - let session = StoredSession { - response_id: "resp_1".into(), - conversation_id: Some("conv_1".into()), - messages: vec![sample_message("hello")], - model: "gpt-test".into(), - created_at: 10, - insertion_index: 0, - metadata: HashMap::from([("trace".into(), json!("abc"))]), - }; - - store.put_session(&session).await.unwrap(); - - let loaded = store.get_by_response_id("resp_1").await.unwrap().unwrap(); - assert_eq!(loaded.response_id, "resp_1"); - assert_eq!(loaded.conversation_id.as_deref(), Some("conv_1")); - assert_eq!(loaded.messages.len(), 1); - assert_eq!(loaded.model, "gpt-test"); - assert_eq!(loaded.metadata.get("trace"), Some(&json!("abc"))); - - let by_conversation = store.get_by_conversation_id("conv_1").await.unwrap(); - assert_eq!(by_conversation.len(), 1); - assert_eq!(by_conversation[0].response_id, "resp_1"); - - store.delete_session("resp_1").await.unwrap(); - assert!(store.get_by_response_id("resp_1").await.unwrap().is_none()); - } - - #[tokio::test] - async fn in_memory_session_store_returns_conversation_sessions_in_created_order() { - let store = InMemorySessionStore::default(); - let newer = StoredSession { - response_id: "resp_2".into(), - conversation_id: Some("conv_1".into()), - messages: vec![sample_message("newer")], - model: "gpt-test".into(), - created_at: 20, - insertion_index: 0, - metadata: HashMap::new(), - }; - let older = StoredSession { - response_id: "resp_1".into(), - conversation_id: Some("conv_1".into()), - messages: vec![sample_message("older")], - model: "gpt-test".into(), - created_at: 10, - insertion_index: 0, - metadata: HashMap::new(), - }; - - store.put_session(&newer).await.unwrap(); - store.put_session(&older).await.unwrap(); - - let sessions = store.get_by_conversation_id("conv_1").await.unwrap(); - assert_eq!(sessions.len(), 2); - assert_eq!(sessions[0].response_id, "resp_1"); - assert_eq!(sessions[1].response_id, "resp_2"); - } - - #[tokio::test] - async fn in_memory_session_store_breaks_created_at_ties_by_insertion_order() { - let store = InMemorySessionStore::default(); - let first = StoredSession { - response_id: "resp_1".into(), - conversation_id: Some("conv_1".into()), - messages: vec![sample_message("first")], - model: "gpt-test".into(), - created_at: 10, - insertion_index: 0, - metadata: HashMap::new(), - }; - let second = StoredSession { - response_id: "resp_2".into(), - conversation_id: Some("conv_1".into()), - messages: vec![sample_message("second")], - model: "gpt-test".into(), - created_at: 10, - insertion_index: 0, - metadata: HashMap::new(), - }; - - store.put_session(&first).await.unwrap(); - store.put_session(&second).await.unwrap(); - - let sessions = store.get_by_conversation_id("conv_1").await.unwrap(); - assert_eq!(sessions.len(), 2); - assert_eq!(sessions[0].response_id, "resp_1"); - assert_eq!(sessions[1].response_id, "resp_2"); - assert!(sessions[0].insertion_index < sessions[1].insertion_index); - } -} diff --git a/crates/aisix-llm/src/types/openai/responses.rs b/crates/aisix-llm/src/types/openai/responses.rs index 5194ebf..021fdcc 100644 --- a/crates/aisix-llm/src/types/openai/responses.rs +++ b/crates/aisix-llm/src/types/openai/responses.rs @@ -14,6 +14,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; +use super::ChatMessage; + // ── Request types ── /// Responses API request. @@ -100,6 +102,10 @@ pub struct ResponsesApiRequest { #[serde(skip_serializing_if = "Option::is_none")] pub truncation: Option, + + /// Internal replay history injected by the proxy before bridge conversion. + #[serde(skip)] + pub replay_messages: Vec, } /// Context management configuration for Responses requests. diff --git a/src/lib.rs b/src/lib.rs index f360d37..9f7f7ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -86,19 +86,19 @@ pub async fn run_with_provider( ) -> Result<()> { let resources = Arc::new(config::entities::ResourceRegistry::new(config_provider.clone()).await); + let message_history_storage: Arc = + Arc::new(proxy::message_history::InMemoryMessageHistoryStorage::default()); - let gateway = Arc::new( - gateway::Gateway::new( - gateway::providers::default_provider_registry() - .context("failed to build default gateway provider registry")?, - ) - .with_session_store(Arc::new(gateway::session::InMemorySessionStore::default())), - ); + let gateway = Arc::new(gateway::Gateway::new( + gateway::providers::default_provider_registry() + .context("failed to build default gateway provider registry")?, + )); let proxy_router = proxy::create_router(proxy::AppState::new( config.clone(), resources.clone(), gateway, + message_history_storage, )); let res = select! { diff --git a/src/proxy/handlers/chat_completions/mod.rs b/src/proxy/handlers/chat_completions/mod.rs index ab42886..487a18a 100644 --- a/src/proxy/handlers/chat_completions/mod.rs +++ b/src/proxy/handlers/chat_completions/mod.rs @@ -31,6 +31,7 @@ impl FormatHandlerAdapter for ChatCompletionsAdapter { type StreamChunk = ChatCompletionChunk; type Error = ChatCompletionError; type Collector = StreamOutputCollector; + type LifecycleState = (); fn span_name() -> &'static str { "aisix.llm.chat_completions" diff --git a/src/proxy/handlers/chat_completions/types.rs b/src/proxy/handlers/chat_completions/types.rs index e439c85..188db1a 100644 --- a/src/proxy/handlers/chat_completions/types.rs +++ b/src/proxy/handlers/chat_completions/types.rs @@ -1,14 +1,15 @@ -use axum::{ - Json, - response::{IntoResponse, Response}, -}; -use http::StatusCode; +use axum::response::{IntoResponse, Response}; use thiserror::Error; use tokio::time::error::Elapsed; use crate::{ gateway::error::GatewayError, - proxy::hooks::{authorization::AuthorizationError, rate_limit::RateLimitError}, + proxy::{ + handlers::openai_error::{ + gateway_error_response, missing_model_response, timeout_response, + }, + hooks::{authorization::AuthorizationError, rate_limit::RateLimitError}, + }, }; #[derive(Debug, Error)] @@ -31,65 +32,12 @@ impl IntoResponse for ChatCompletionError { ChatCompletionError::AuthorizationError(err) => err.into_response(), ChatCompletionError::RateLimitError(RateLimitError::Raw(resp)) => resp, ChatCompletionError::GatewayError(err) => { - let status = err.status_code(); - let (message, error_type, code) = match err { - GatewayError::Provider { .. } - | GatewayError::Http(_) - | GatewayError::Stream(_) => ( - "Provider error".to_string(), - "server_error", - "provider_error", - ), - GatewayError::Internal(_) => ( - "Gateway internal error".to_string(), - "server_error", - "internal_error", - ), - _ => ( - err.to_string(), - if status.is_client_error() { - "invalid_request_error" - } else { - "server_error" - }, - "gateway_error", - ), - }; - - ( - status, - Json(serde_json::json!({ - "error": { - "message": message, - "type": error_type, - "code": code - } - })), - ) - .into_response() + gateway_error_response("Chat Completions", &err, err.status_code()) + } + ChatCompletionError::Timeout(_) => timeout_response("Chat Completions"), + ChatCompletionError::MissingModelInContext => { + missing_model_response("Chat Completions") } - ChatCompletionError::Timeout(_) => ( - StatusCode::GATEWAY_TIMEOUT, - Json(serde_json::json!({ - "error": { - "message": "Provider request timed out", - "type": "server_error", - "code": "request_timeout" - } - })), - ) - .into_response(), - ChatCompletionError::MissingModelInContext => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { - "message": "model missing in request context", - "type": "server_error", - "code": "internal_error" - } - })), - ) - .into_response(), } } } diff --git a/src/proxy/handlers/embeddings/types.rs b/src/proxy/handlers/embeddings/types.rs index 4d7f813..2b976f8 100644 --- a/src/proxy/handlers/embeddings/types.rs +++ b/src/proxy/handlers/embeddings/types.rs @@ -1,11 +1,18 @@ -use axum::{Json, response::IntoResponse}; +use axum::{ + response::{IntoResponse, Response}, +}; use http::StatusCode; use thiserror::Error; use tokio::time::error::Elapsed; use crate::{ gateway::error::GatewayError, - proxy::hooks::{authorization::AuthorizationError, rate_limit::RateLimitError}, + proxy::{ + handlers::openai_error::{ + gateway_error_response, missing_model_response, timeout_response, + }, + hooks::{authorization::AuthorizationError, rate_limit::RateLimitError}, + }, }; #[derive(Debug, Error)] @@ -23,7 +30,7 @@ pub enum EmbeddingError { } impl IntoResponse for EmbeddingError { - fn into_response(self) -> axum::response::Response { + fn into_response(self) -> Response { match self { EmbeddingError::AuthorizationError(err) => err.into_response(), EmbeddingError::RateLimitError(RateLimitError::Raw(resp)) => resp, @@ -35,69 +42,10 @@ impl IntoResponse for EmbeddingError { GatewayError::EmbeddingsNotSupported { .. } => StatusCode::NOT_IMPLEMENTED, _ => err.status_code(), }; - let (message, error_type, code) = match err { - GatewayError::Provider { .. } - | GatewayError::Http(_) - | GatewayError::Stream(_) => ( - "Provider error".to_string(), - "server_error", - "provider_error", - ), - GatewayError::EmbeddingsNotSupported { .. } => ( - "Provider error".to_string(), - "server_error", - "provider_error", - ), - GatewayError::Internal(_) => ( - "Gateway internal error".to_string(), - "server_error", - "internal_error", - ), - _ => ( - err.to_string(), - if status.is_client_error() { - "invalid_request_error" - } else { - "server_error" - }, - "gateway_error", - ), - }; - - ( - status, - Json(serde_json::json!({ - "error": { - "message": message, - "type": error_type, - "code": code - } - })), - ) - .into_response() + gateway_error_response("Embeddings", &err, status) } - EmbeddingError::Timeout(_) => ( - StatusCode::GATEWAY_TIMEOUT, - Json(serde_json::json!({ - "error": { - "message": "Provider request timed out", - "type": "server_error", - "code": "request_timeout" - } - })), - ) - .into_response(), - EmbeddingError::MissingModelInContext => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { - "message": "model missing in request context", - "type": "server_error", - "code": "internal_error" - } - })), - ) - .into_response(), + EmbeddingError::Timeout(_) => timeout_response("Embeddings"), + EmbeddingError::MissingModelInContext => missing_model_response("Embeddings"), } } } diff --git a/src/proxy/handlers/format_handler.rs b/src/proxy/handlers/format_handler.rs index d977a69..81e228d 100644 --- a/src/proxy/handlers/format_handler.rs +++ b/src/proxy/handlers/format_handler.rs @@ -1,5 +1,6 @@ use std::{convert::Infallible, time::Duration}; +use async_trait::async_trait; use axum::{ Json, extract::State, @@ -43,21 +44,24 @@ type AdapterRequest = ::Request; type AdapterResponse = ::Response; type AdapterCollector = ::Collector; +#[async_trait] pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { type Format: ChatFormat< Request = Self::Request, Response = Self::Response, StreamChunk = Self::StreamChunk, >; - type Request; + type Request: Sync; type Response: Serialize; type StreamChunk: Serialize + Send + 'static; type Error: IntoResponse + + std::fmt::Display + From + From + From + From; type Collector: Default + Send + 'static; + type LifecycleState: Default + Send + 'static; fn span_name() -> &'static str; @@ -81,6 +85,49 @@ pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { fn output_message_span_properties(collector: &Self::Collector) -> Vec<(String, String)>; + async fn prepare_lifecycle( + _state: &AppState, + _request_ctx: &mut RequestContext, + _request: &mut Self::Request, + ) -> Result { + Ok(Self::LifecycleState::default()) + } + + async fn handle_complete_response( + _state: &AppState, + _request_ctx: &mut RequestContext, + _lifecycle_state: &mut Self::LifecycleState, + _response: &mut Self::Response, + _usage: &Usage, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn handle_stream_item( + _state: &AppState, + _request_ctx: &mut RequestContext, + _lifecycle_state: &mut Self::LifecycleState, + _chunk: &mut Self::StreamChunk, + ) { + } + + async fn handle_stream_success( + _state: &AppState, + _request_ctx: &mut RequestContext, + _lifecycle_state: Self::LifecycleState, + _usage: Option<&Usage>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + async fn handle_stream_failure( + _state: &AppState, + _request_ctx: &mut RequestContext, + _lifecycle_state: Self::LifecycleState, + ) -> Result<(), Self::Error> { + Ok(()) + } + fn serialize_stream_item(chunk: &Self::StreamChunk) -> SseEvent { let mut event = SseEvent::default().data(::serialize_chunk_payload(chunk)); @@ -96,6 +143,10 @@ pub(crate) trait FormatHandlerAdapter: Send + Sync + 'static { None } + fn lifecycle_error_event(_error: &Self::Error) -> Option { + None + } + fn end_of_stream_event(_saw_item: bool) -> Option { None } @@ -134,6 +185,8 @@ where })?; let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; let provider_base_url = provider_instance.effective_base_url().ok(); + let mut lifecycle_state = + A::prepare_lifecycle(&state, &mut request_ctx, &mut request_data).await?; let span = Span::enter_with_local_parent(A::span_name()); apply_span_properties( @@ -146,25 +199,46 @@ where ); let (response, span) = (WithSpan { - inner: maybe_timeout( - timeout, - gateway.chat::>(&request_data, &provider_instance), - ), + inner: maybe_timeout(timeout, async { + Ok(state + .gateway() + .chat::>(&request_data, &provider_instance) + .await?) + }), span: Some(span), }) .await; match response { - Ok(Ok(ChatResponse::Complete { response, usage })) => { + Ok(Ok(ChatResponse::Complete { + mut response, + usage, + })) => { + A::handle_complete_response( + &state, + &mut request_ctx, + &mut lifecycle_state, + &mut response, + &usage, + ) + .await?; span.add_properties(|| A::response_span_properties(&response, &usage)); handle_regular_response::(response, usage, &mut request_ctx).await } Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { - handle_stream_response::(stream, usage_rx, &mut request_ctx, span).await + handle_stream_response::( + state, + stream, + usage_rx, + &mut request_ctx, + span, + lifecycle_state, + ) + .await } Ok(Err(err)) => { span.add_property(|| ("error.type", "gateway_error")); - Err(err.into()) + Err(err) } Err(err) => { span.add_property(|| ("error.type", "timeout")); @@ -212,14 +286,11 @@ fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::R }); } -async fn finalize_stream_usage( +async fn finalize_stream_usage_observation( request_ctx: &mut RequestContext, usage_rx: &mut Option>, span: &Span, - output_message_properties: Vec<(String, String)>, ) { - span.add_properties(|| output_message_properties); - if let Some(mut usage_rx) = usage_rx.take() { match usage_rx.try_recv() { Ok(usage) => { @@ -240,11 +311,72 @@ async fn finalize_stream_usage( } } +async fn finalize_stream_success( + state: &AppState, + request_ctx: &mut RequestContext, + usage_rx: &mut Option>, + span: &Span, + output_message_properties: Vec<(String, String)>, + lifecycle_state: &mut Option, +) -> Result<(), A::Error> +where + A: FormatHandlerAdapter, +{ + span.add_properties(|| output_message_properties); + + let Some(lifecycle_state) = lifecycle_state.take() else { + finalize_stream_usage_observation(request_ctx, usage_rx, span).await; + return Ok(()); + }; + + if let Some(mut usage_rx) = usage_rx.take() { + match usage_rx.try_recv() { + Ok(usage) => { + if let Err(err) = hooks::rate_limit::post_check_streaming(request_ctx, &usage).await + { + error!("Rate limit post_check_streaming error: {}", err); + } + hooks::observability::record_streaming_usage(request_ctx, &usage).await; + span.add_properties(|| usage_span_properties(&usage)); + + A::handle_stream_success(state, request_ctx, lifecycle_state, Some(&usage)).await?; + } + Err(TryRecvError::Empty) => match usage_rx.await { + Ok(usage) => { + if let Err(err) = + hooks::rate_limit::post_check_streaming(request_ctx, &usage).await + { + error!("Rate limit post_check_streaming error: {}", err); + } + hooks::observability::record_streaming_usage(request_ctx, &usage).await; + span.add_properties(|| usage_span_properties(&usage)); + + A::handle_stream_success(state, request_ctx, lifecycle_state, Some(&usage)) + .await?; + } + Err(err) => { + error!("Failed to receive streaming usage from gateway: {}", err); + A::handle_stream_success(state, request_ctx, lifecycle_state, None).await?; + } + }, + Err(TryRecvError::Closed) => { + error!("Failed to receive streaming usage from gateway: channel closed"); + A::handle_stream_success(state, request_ctx, lifecycle_state, None).await?; + } + } + return Ok(()); + } + + A::handle_stream_success(state, request_ctx, lifecycle_state, None).await +} + async fn handle_stream_response( + state: AppState, stream: ChatResponseStream>, usage_rx: oneshot::Receiver, request_ctx: &mut RequestContext, span: Span, + lifecycle_state: A::LifecycleState, ) -> Result where A: FormatHandlerAdapter, @@ -252,8 +384,10 @@ where use futures::stream::StreamExt; let stream_request_ctx = request_ctx.clone(); + let stream_state = state.clone(); let sse_stream = futures::stream::unfold( ( + stream_state, stream, span, stream_request_ctx, @@ -262,8 +396,10 @@ where Some(usage_rx), AdapterCollector::::default(), false, + Some(lifecycle_state), ), |( + state, mut stream, span, mut request_ctx, @@ -272,6 +408,7 @@ where mut usage_rx, mut output_collector, mut first_output_arrived, + mut lifecycle_state, )| async move { if should_terminate { drop(span); @@ -279,7 +416,16 @@ where } match stream.next().await { - Some(Ok(chunk)) => { + Some(Ok(mut chunk)) => { + if let Some(lifecycle_state) = lifecycle_state.as_mut() { + A::handle_stream_item( + &state, + &mut request_ctx, + lifecycle_state, + &mut chunk, + ); + } + A::record_stream_item(&mut output_collector, &chunk); let now_starts_output = !first_output_arrived && A::starts_output(&chunk); @@ -297,6 +443,7 @@ where Some(( Ok::(A::serialize_stream_item(&chunk)), ( + state, stream, span, request_ctx, @@ -305,24 +452,31 @@ where usage_rx, output_collector, first_output_arrived, + lifecycle_state, ), )) } Some(Err(err)) => { error!("Gateway stream error: {}", err); span.add_property(|| ("error.type", "stream_error")); - finalize_stream_usage( - &mut request_ctx, - &mut usage_rx, - &span, - A::output_message_span_properties(&output_collector), - ) - .await; + span.add_properties(|| A::output_message_span_properties(&output_collector)); + + if let Some(lifecycle_state) = lifecycle_state.take() { + if let Err(lifecycle_err) = + A::handle_stream_failure(&state, &mut request_ctx, lifecycle_state) + .await + { + error!("Stream failure lifecycle error: {}", lifecycle_err); + } + } + + finalize_stream_usage_observation(&mut request_ctx, &mut usage_rx, &span).await; if let Some(event) = A::stream_error_event(&err) { Some(( Ok(event), ( + state, stream, span, request_ctx, @@ -331,6 +485,7 @@ where usage_rx, output_collector, first_output_arrived, + lifecycle_state, ), )) } else { @@ -339,18 +494,49 @@ where } } None => { - finalize_stream_usage( + match finalize_stream_success::( + &state, &mut request_ctx, &mut usage_rx, &span, A::output_message_span_properties(&output_collector), + &mut lifecycle_state, ) - .await; + .await + { + Ok(()) => {} + Err(err) => { + error!("Stream success lifecycle error: {}", err); + span.add_property(|| ("error.type", "stream_success_lifecycle_error")); + + if let Some(event) = A::lifecycle_error_event(&err) { + return Some(( + Ok(event), + ( + state, + stream, + span, + request_ctx, + true, + saw_item, + usage_rx, + output_collector, + first_output_arrived, + lifecycle_state, + ), + )); + } + + drop(span); + return None; + } + } if let Some(event) = A::end_of_stream_event(saw_item) { Some(( Ok(event), ( + state, stream, span, request_ctx, @@ -359,6 +545,7 @@ where usage_rx, output_collector, first_output_arrived, + lifecycle_state, ), )) } else { diff --git a/src/proxy/handlers/messages/mod.rs b/src/proxy/handlers/messages/mod.rs index 23cbe05..3e00a52 100644 --- a/src/proxy/handlers/messages/mod.rs +++ b/src/proxy/handlers/messages/mod.rs @@ -61,6 +61,7 @@ impl FormatHandlerAdapter for MessagesAdapter { type StreamChunk = AnthropicStreamEvent; type Error = MessagesError; type Collector = StreamOutputCollector; + type LifecycleState = (); fn span_name() -> &'static str { "aisix.llm.messages" diff --git a/src/proxy/handlers/mod.rs b/src/proxy/handlers/mod.rs index be5eb96..4257741 100644 --- a/src/proxy/handlers/mod.rs +++ b/src/proxy/handlers/mod.rs @@ -3,4 +3,5 @@ pub mod embeddings; pub mod format_handler; pub mod messages; pub mod models; +pub(crate) mod openai_error; pub mod responses; diff --git a/src/proxy/handlers/openai_error.rs b/src/proxy/handlers/openai_error.rs new file mode 100644 index 0000000..f6cb13b --- /dev/null +++ b/src/proxy/handlers/openai_error.rs @@ -0,0 +1,146 @@ +use axum::{ + Json, + response::{IntoResponse, Response}, +}; +use http::StatusCode; +use log::error; +use serde_json::{Map, Value, json}; + +use crate::gateway::error::GatewayError; + +pub(crate) fn gateway_error_response( + handler_name: &'static str, + err: &GatewayError, + status: StatusCode, +) -> Response { + error!("{handler_name} gateway error: {err}"); + + match err { + GatewayError::Provider { body, .. } => { + provider_error_response(status, provider_error_message(body), Some(body.clone())) + } + GatewayError::Http(http_error) => provider_error_response( + status, + format!("Upstream HTTP error: {http_error}"), + Some(Value::String(http_error.to_string())), + ), + GatewayError::Stream(stream_error) => provider_error_response( + status, + format!("Upstream stream error: {stream_error}"), + Some(Value::String(stream_error.clone())), + ), + GatewayError::EmbeddingsNotSupported { .. } => { + provider_error_response(status, "Provider error".into(), None) + } + GatewayError::Internal(_) => internal_error_response(status), + _ => generic_gateway_error_response(status, err), + } +} + +pub(crate) fn timeout_response(handler_name: &'static str) -> Response { + error!("{handler_name} request timed out"); + + ( + StatusCode::GATEWAY_TIMEOUT, + Json(json!({ + "error": { + "message": "Provider request timed out", + "type": "server_error", + "code": "request_timeout" + } + })), + ) + .into_response() +} + +pub(crate) fn missing_model_response(handler_name: &'static str) -> Response { + error!("{handler_name} model missing in request context"); + + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": "model missing in request context", + "type": "server_error", + "code": "internal_error" + } + })), + ) + .into_response() +} + +fn provider_error_response(status: StatusCode, message: String, details: Option) -> Response { + let mut error_object = Map::from_iter([ + ("message".to_string(), Value::String(message)), + ("type".to_string(), Value::String("server_error".into())), + ("code".to_string(), Value::String("provider_error".into())), + ]); + if let Some(details) = details.filter(|details| !details.is_null()) { + error_object.insert("details".to_string(), details); + } + + ( + status, + Json(json!({ + "error": Value::Object(error_object), + })), + ) + .into_response() +} + +fn internal_error_response(status: StatusCode) -> Response { + ( + status, + Json(json!({ + "error": { + "message": "Gateway internal error", + "type": "server_error", + "code": "internal_error" + } + })), + ) + .into_response() +} + +fn generic_gateway_error_response(status: StatusCode, err: &GatewayError) -> Response { + ( + status, + Json(json!({ + "error": { + "message": err.to_string(), + "type": if status.is_client_error() { + "invalid_request_error" + } else { + "server_error" + }, + "code": "gateway_error" + } + })), + ) + .into_response() +} + +fn provider_error_message(body: &Value) -> String { + provider_nested_message(body).unwrap_or_else(|| "Provider error".into()) +} + +fn provider_nested_message(body: &Value) -> Option { + match body { + Value::Object(map) => match map.get("error") { + Some(Value::Object(error_object)) => error_object.get("message").and_then(value_to_string), + Some(Value::String(text)) => Some(text.clone()), + _ => map.get("message").and_then(value_to_string), + }, + Value::String(text) if !text.is_empty() => Some(text.clone()), + _ => None, + } +} + +fn value_to_string(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.clone()), + Value::Number(number) => Some(number.to_string()), + Value::Bool(flag) => Some(flag.to_string()), + _ => None, + } +} diff --git a/src/proxy/handlers/responses/mod.rs b/src/proxy/handlers/responses/mod.rs index b8d39a1..9f6f424 100644 --- a/src/proxy/handlers/responses/mod.rs +++ b/src/proxy/handlers/responses/mod.rs @@ -1,9 +1,16 @@ +mod runtime; mod span_attributes; mod types; +use async_trait::async_trait; use axum::response::sse::Event as SseEvent; use fastrace::Span; use reqwest::Url; +use runtime::{ + ResponsesLifecycleState, accumulate_complete, accumulate_stream_event, + accumulate_stream_success, build_merged_input_messages, init_lifecycle, load_previous_messages, + persist_if_enabled, +}; use span_attributes::{ StreamOutputCollector, chunk_span_properties, event_starts_output, request_span_properties, response_span_properties, @@ -19,10 +26,11 @@ use crate::{ common::Usage, openai::responses::{ ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent, + ResponsesOutputItem, }, }, }, - proxy::handlers::format_handler::FormatHandlerAdapter, + proxy::{AppState, handlers::format_handler::FormatHandlerAdapter, hooks::RequestContext}, }; fn serialize_stream_event(event: &ResponsesApiStreamEvent) -> SseEvent { @@ -38,6 +46,7 @@ fn serialize_stream_event(event: &ResponsesApiStreamEvent) -> SseEvent { pub(crate) struct ResponsesAdapter; +#[async_trait] impl FormatHandlerAdapter for ResponsesAdapter { type Format = ResponsesApiFormat; type Request = ResponsesApiRequest; @@ -45,6 +54,7 @@ impl FormatHandlerAdapter for ResponsesAdapter { type StreamChunk = ResponsesApiStreamEvent; type Error = ResponsesError; type Collector = StreamOutputCollector; + type LifecycleState = ResponsesLifecycleState; fn span_name() -> &'static str { "aisix.llm.responses" @@ -86,6 +96,71 @@ impl FormatHandlerAdapter for ResponsesAdapter { collector.output_message_span_properties() } + async fn prepare_lifecycle( + state: &AppState, + _request_ctx: &mut RequestContext, + request: &mut Self::Request, + ) -> Result { + let mut lifecycle_state = init_lifecycle(request); + let storage = state.message_history_storage(); + let previous_messages = + load_previous_messages(storage.as_ref(), request.previous_response_id.as_deref()) + .await?; + lifecycle_state.merged_input_messages = + build_merged_input_messages(request, &previous_messages)?; + request.replay_messages = previous_messages; + + Ok(lifecycle_state) + } + + async fn handle_complete_response( + state: &AppState, + _request_ctx: &mut RequestContext, + lifecycle_state: &mut Self::LifecycleState, + response: &mut Self::Response, + usage: &Usage, + ) -> Result<(), Self::Error> { + let stored_history = accumulate_complete( + lifecycle_state, + response, + usage.clone().with_derived_total(), + )?; + let storage = state.message_history_storage(); + persist_if_enabled(storage.as_ref(), lifecycle_state, &stored_history).await?; + rewrite_response_ids(response, &stored_history.response_id); + Ok(()) + } + + fn handle_stream_item( + _state: &AppState, + _request_ctx: &mut RequestContext, + lifecycle_state: &mut Self::LifecycleState, + chunk: &mut Self::StreamChunk, + ) { + accumulate_stream_event(lifecycle_state, chunk); + rewrite_stream_event_ids(chunk, &lifecycle_state.response_id); + } + + async fn handle_stream_success( + state: &AppState, + _request_ctx: &mut RequestContext, + lifecycle_state: Self::LifecycleState, + usage: Option<&Usage>, + ) -> Result<(), Self::Error> { + let stored_history = accumulate_stream_success(&lifecycle_state, usage)?; + let storage = state.message_history_storage(); + persist_if_enabled(storage.as_ref(), &lifecycle_state, &stored_history).await?; + Ok(()) + } + + async fn handle_stream_failure( + _state: &AppState, + _request_ctx: &mut RequestContext, + _lifecycle_state: Self::LifecycleState, + ) -> Result<(), Self::Error> { + Ok(()) + } + fn serialize_stream_item(chunk: &Self::StreamChunk) -> SseEvent { serialize_stream_event(chunk) } @@ -95,4 +170,38 @@ impl FormatHandlerAdapter for ResponsesAdapter { message: error.to_string(), })) } + + fn lifecycle_error_event(error: &Self::Error) -> Option { + Some(serialize_stream_event(&ResponsesApiStreamEvent::Error { + message: error.to_string(), + })) + } +} + +fn rewrite_response_ids(response: &mut ResponsesApiResponse, response_id: &str) { + response.id = response_id.to_owned(); + for (output_index, item) in response.output.iter_mut().enumerate() { + rewrite_output_item_ids(item, response_id, output_index); + } +} + +fn rewrite_stream_event_ids(event: &mut ResponsesApiStreamEvent, response_id: &str) { + match event { + ResponsesApiStreamEvent::ResponseCreated { response } + | ResponsesApiStreamEvent::ResponseInProgress { response } + | ResponsesApiStreamEvent::ResponseCompleted { response } => { + rewrite_response_ids(response, response_id); + } + ResponsesApiStreamEvent::OutputItemAdded { output_index, item } + | ResponsesApiStreamEvent::OutputItemDone { output_index, item } => { + rewrite_output_item_ids(item, response_id, *output_index); + } + _ => {} + } +} + +fn rewrite_output_item_ids(item: &mut ResponsesOutputItem, response_id: &str, output_index: usize) { + if let ResponsesOutputItem::Message { id, .. } = item { + *id = format!("{}_message_{}", response_id, output_index); + } } diff --git a/src/proxy/handlers/responses/runtime.rs b/src/proxy/handlers/responses/runtime.rs new file mode 100644 index 0000000..4d035d8 --- /dev/null +++ b/src/proxy/handlers/responses/runtime.rs @@ -0,0 +1,944 @@ +use std::collections::{BTreeMap, HashMap}; + +use serde_json::{Map, Value}; +use uuid::Uuid; + +use crate::{ + gateway::{ + error::{GatewayError, Result}, + types::{ + common::Usage, + openai::{ + ChatMessage, ContentPart, FunctionCall, MessageContent, ToolCall, + responses::{ + ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent, + ResponsesContent, ResponsesContentPart, ResponsesInput, ResponsesInputItem, + ResponsesOutputContent, ResponsesOutputItem, ResponsesUsage, + }, + }, + }, + }, + proxy::message_history::{ + MessageHistoryStorage, StoredMessageHistory, StoredMessageHistoryStatus, + }, +}; + +#[derive(Debug, Clone, Default)] +pub(crate) struct ResponsesLifecycleState { + pub response_id: String, + pub previous_response_id: Option, + pub merged_input_messages: Vec, + pub model: String, + pub metadata: HashMap, + pub store: bool, + pub accumulator: ResponsesStreamAccumulator, +} + +impl ResponsesLifecycleState { + fn request_metadata_value(&self) -> Option { + if self.metadata.is_empty() { + None + } else { + Some(Value::Object( + self.metadata + .iter() + .map(|(key, value)| (key.clone(), value.clone())) + .collect::>(), + )) + } + } +} + +#[derive(Debug, Clone, Default)] +pub(crate) struct ResponsesStreamAccumulator { + response_id: Option, + model: Option, + created_at: Option, + metadata: Option, + usage: Option, + output_items: BTreeMap, + failed: bool, +} + +impl ResponsesStreamAccumulator { + pub(crate) fn record_event(&mut self, event: &ResponsesApiStreamEvent) { + match event { + ResponsesApiStreamEvent::ResponseCreated { response } + | ResponsesApiStreamEvent::ResponseInProgress { response } + | ResponsesApiStreamEvent::ResponseCompleted { response } => { + self.record_response(response); + } + ResponsesApiStreamEvent::OutputItemAdded { output_index, item } + | ResponsesApiStreamEvent::OutputItemDone { output_index, item } => { + self.output_items.insert(*output_index, item.clone()); + } + ResponsesApiStreamEvent::ContentPartAdded { + output_index, + content_index, + part, + } + | ResponsesApiStreamEvent::ContentPartDone { + output_index, + content_index, + part, + } => { + if let Some(content_part) = + self.ensure_output_text_part(*output_index, *content_index) + { + *content_part = part.clone(); + } + } + ResponsesApiStreamEvent::OutputTextDelta { + output_index, + content_index, + delta, + } => { + if let Some(ResponsesOutputContent::OutputText { text }) = + self.ensure_output_text_part(*output_index, *content_index) + { + text.push_str(delta); + } + } + ResponsesApiStreamEvent::OutputTextDone { + output_index, + content_index, + text, + } => { + if let Some(ResponsesOutputContent::OutputText { text: current }) = + self.ensure_output_text_part(*output_index, *content_index) + { + *current = text.clone(); + } + } + ResponsesApiStreamEvent::FunctionCallArgumentsDelta { + output_index, + delta, + } => { + if let ResponsesOutputItem::FunctionCall { arguments, .. } = + self.ensure_function_call_item(*output_index) + { + arguments.push_str(delta); + } + } + ResponsesApiStreamEvent::FunctionCallArgumentsDone { + output_index, + arguments, + } => { + if let ResponsesOutputItem::FunctionCall { + arguments: current, .. + } = self.ensure_function_call_item(*output_index) + { + *current = arguments.clone(); + } + } + ResponsesApiStreamEvent::Error { .. } => { + self.failed = true; + } + } + } + + pub(crate) fn response_snapshot( + &self, + previous_response_id: Option, + fallback_model: &str, + fallback_metadata: Option, + fallback_usage: Option, + ) -> Result { + if self.failed { + return Err(GatewayError::Validation( + "responses stream ended in error; no completed snapshot available".into(), + )); + } + + let response_id = self.response_id.clone().ok_or_else(|| { + GatewayError::Validation( + "responses stream completed without an upstream response id".into(), + ) + })?; + + let usage = fallback_usage + .as_ref() + .map(responses_usage_from_common) + .or_else(|| self.usage.clone()) + .unwrap_or_default(); + + Ok(ResponsesApiResponse { + id: response_id, + object: "response".into(), + created_at: self.created_at.unwrap_or_default(), + model: self + .model + .clone() + .unwrap_or_else(|| fallback_model.to_owned()), + output: self.output_items.values().cloned().collect(), + status: "completed".into(), + usage, + metadata: self.metadata.clone().or(fallback_metadata), + previous_response_id, + }) + } + + fn record_response(&mut self, response: &ResponsesApiResponse) { + if self.response_id.is_none() && !response.id.is_empty() { + self.response_id = Some(response.id.clone()); + } + if self.model.is_none() && !response.model.is_empty() { + self.model = Some(response.model.clone()); + } + if self.created_at.is_none() { + self.created_at = Some(response.created_at); + } + if self.metadata.is_none() && response.metadata.is_some() { + self.metadata = response.metadata.clone(); + } + self.usage = Some(response.usage.clone()); + for (output_index, item) in response.output.iter().cloned().enumerate() { + self.output_items.insert(output_index, item); + } + } + + fn ensure_output_text_part( + &mut self, + output_index: usize, + content_index: usize, + ) -> Option<&mut ResponsesOutputContent> { + let item = self.ensure_message_item(output_index); + let ResponsesOutputItem::Message { content, .. } = item else { + return None; + }; + + if content.len() <= content_index { + content.resize_with(content_index + 1, || ResponsesOutputContent::OutputText { + text: String::new(), + }); + } + content.get_mut(content_index) + } + + fn ensure_message_item(&mut self, output_index: usize) -> &mut ResponsesOutputItem { + let response_id = self.response_id.as_deref().unwrap_or("response").to_owned(); + self.output_items + .entry(output_index) + .or_insert_with(|| ResponsesOutputItem::Message { + id: response_message_output_id(&response_id, output_index), + role: "assistant".into(), + content: vec![], + status: "in_progress".into(), + }) + } + + fn ensure_function_call_item(&mut self, output_index: usize) -> &mut ResponsesOutputItem { + let response_id = self.response_id.as_deref().unwrap_or("response").to_owned(); + self.output_items.entry(output_index).or_insert_with(|| { + let id = response_function_call_output_id(&response_id, output_index); + ResponsesOutputItem::FunctionCall { + id: id.clone(), + call_id: id, + name: format!("tool_{}", output_index), + arguments: String::new(), + status: "in_progress".into(), + } + }) + } +} + +pub(crate) fn init_lifecycle(request: &ResponsesApiRequest) -> ResponsesLifecycleState { + ResponsesLifecycleState { + response_id: generate_response_id(), + previous_response_id: request.previous_response_id.clone(), + merged_input_messages: vec![], + model: request.model.clone(), + metadata: request_metadata(request), + store: request.store != Some(false), + accumulator: ResponsesStreamAccumulator::default(), + } +} + +pub(crate) async fn load_previous_messages( + storage: &S, + previous_response_id: Option<&str>, +) -> Result> +where + S: MessageHistoryStorage + ?Sized, +{ + let Some(previous_response_id) = previous_response_id else { + return Ok(vec![]); + }; + + let history = storage + .get_by_response_id(previous_response_id) + .await? + .ok_or_else(|| { + GatewayError::Validation(format!( + "previous_response_not_found: {}", + previous_response_id + )) + })?; + Ok(history.cumulative_messages) +} + +pub(crate) fn build_merged_input_messages( + request: &ResponsesApiRequest, + previous_messages: &[ChatMessage], +) -> Result> { + let mut merged_input_messages = previous_messages.to_vec(); + merged_input_messages.extend(request_input_messages(request)?); + Ok(merged_input_messages) +} + +pub(crate) fn accumulate_stream_event( + state: &mut ResponsesLifecycleState, + event: &ResponsesApiStreamEvent, +) { + state.accumulator.record_event(event); +} + +pub(crate) fn accumulate_complete( + state: &ResponsesLifecycleState, + response: &ResponsesApiResponse, + usage: Usage, +) -> Result { + completed_history(state, response, usage) +} + +pub(crate) fn accumulate_stream_success( + state: &ResponsesLifecycleState, + usage: Option<&Usage>, +) -> Result { + let response = state.accumulator.response_snapshot( + state.previous_response_id.clone(), + &state.model, + state.request_metadata_value(), + usage.cloned(), + )?; + let usage = usage + .cloned() + .unwrap_or_else(|| responses_usage_to_common(&response.usage)) + .with_derived_total(); + completed_history(state, &response, usage) +} + +pub(crate) async fn persist_if_enabled( + storage: &S, + state: &ResponsesLifecycleState, + history: &StoredMessageHistory, +) -> Result<()> +where + S: MessageHistoryStorage + ?Sized, +{ + if state.store { + storage.put(history).await?; + } + Ok(()) +} + +fn completed_history( + state: &ResponsesLifecycleState, + response: &ResponsesApiResponse, + usage: Usage, +) -> Result { + let mut cumulative_messages = state.merged_input_messages.clone(); + cumulative_messages.extend(response_output_to_chat_messages(&response.output)); + + let mut metadata = state.metadata.clone(); + merge_metadata_value(&mut metadata, response.metadata.as_ref()); + + Ok(StoredMessageHistory { + response_id: state.response_id.clone(), + previous_response_id: state.previous_response_id.clone(), + upstream_response_id: Some(response.id.clone()), + cumulative_messages, + model: response.model.clone(), + created_at: response.created_at, + finished_at: Some(response.created_at), + usage: Some(usage), + status: StoredMessageHistoryStatus::Completed, + metadata, + }) +} + +fn generate_response_id() -> String { + format!("aresp_{}", Uuid::new_v4().simple()) +} + +fn request_metadata(request: &ResponsesApiRequest) -> HashMap { + request + .metadata + .as_ref() + .and_then(Value::as_object) + .map(|metadata| metadata.clone().into_iter().collect()) + .unwrap_or_default() +} + +fn request_input_messages(request: &ResponsesApiRequest) -> Result> { + match &request.input { + ResponsesInput::Text(text) => Ok(vec![ChatMessage { + role: "user".into(), + content: Some(MessageContent::Text(text.clone())), + name: None, + tool_calls: None, + tool_call_id: None, + }]), + ResponsesInput::Items(items) => items.iter().try_fold(vec![], |mut messages, item| { + if let Some(message) = request_input_item_to_chat_message(item)? { + messages.push(message); + } + Ok(messages) + }), + } +} + +fn request_input_item_to_chat_message(item: &ResponsesInputItem) -> Result> { + match item { + ResponsesInputItem::Message { role, content } => { + let Some(content) = request_content_to_message_content(content)? else { + return Ok(None); + }; + + Ok(Some(ChatMessage { + role: role.clone(), + content: Some(content), + name: None, + tool_calls: None, + tool_call_id: None, + })) + } + ResponsesInputItem::FunctionCallOutput { call_id, output } => Ok(Some(ChatMessage { + role: "tool".into(), + content: Some(MessageContent::Text(output.clone())), + name: None, + tool_calls: None, + tool_call_id: Some(call_id.clone()), + })), + } +} + +fn request_content_to_message_content( + content: &ResponsesContent, +) -> Result> { + match content { + ResponsesContent::Text(text) => Ok(Some(MessageContent::Text(text.clone()))), + ResponsesContent::Parts(parts) => { + let parts = parts + .iter() + .map(request_content_part_to_content_part) + .collect::>>()? + .into_iter() + .flatten() + .collect::>(); + + if parts.is_empty() { + Ok(None) + } else { + Ok(Some(MessageContent::Parts(parts))) + } + } + } +} + +fn request_content_part_to_content_part( + part: &ResponsesContentPart, +) -> Result> { + match part { + ResponsesContentPart::InputText { text } => { + Ok(Some(ContentPart::Text { text: text.clone() })) + } + ResponsesContentPart::InputImage { + image_url, detail, .. + } => Ok(image_url.as_ref().map(|url| ContentPart::ImageUrl { + image_url: crate::gateway::types::openai::ImageUrl { + url: url.clone(), + detail: detail.clone(), + }, + })), + } +} + +fn response_output_to_chat_messages(output: &[ResponsesOutputItem]) -> Vec { + let mut messages = Vec::new(); + let mut current_assistant_index = None; + + for item in output { + match item { + ResponsesOutputItem::Message { role, content, .. } => { + let message = ChatMessage { + role: role.clone(), + content: response_output_content_to_message_content(content), + name: None, + tool_calls: None, + tool_call_id: None, + }; + messages.push(message); + current_assistant_index = (role == "assistant").then_some(messages.len() - 1); + } + ResponsesOutputItem::FunctionCall { + id, + name, + arguments, + .. + } => { + let assistant_index = current_assistant_index.unwrap_or_else(|| { + messages.push(ChatMessage { + role: "assistant".into(), + content: None, + name: None, + tool_calls: Some(vec![]), + tool_call_id: None, + }); + let index = messages.len() - 1; + current_assistant_index = Some(index); + index + }); + + messages[assistant_index] + .tool_calls + .get_or_insert_with(Vec::new) + .push(ToolCall { + id: id.clone(), + r#type: "function".into(), + function: FunctionCall { + name: name.clone(), + arguments: arguments.clone(), + }, + }); + } + } + } + + messages + .into_iter() + .map(|mut message| { + if message + .tool_calls + .as_ref() + .is_some_and(|tool_calls| tool_calls.is_empty()) + { + message.tool_calls = None; + } + message + }) + .collect() +} + +fn response_output_content_to_message_content( + content: &[ResponsesOutputContent], +) -> Option { + match content { + [] => None, + [ResponsesOutputContent::OutputText { text }] => Some(MessageContent::Text(text.clone())), + multiple => Some(MessageContent::Parts( + multiple + .iter() + .map(|part| match part { + ResponsesOutputContent::OutputText { text } => { + ContentPart::Text { text: text.clone() } + } + }) + .collect(), + )), + } +} + +fn merge_metadata_value(metadata: &mut HashMap, extra: Option<&Value>) { + let Some(extra) = extra.and_then(Value::as_object) else { + return; + }; + + metadata.extend(extra.clone()); +} + +fn responses_usage_from_common(usage: &Usage) -> ResponsesUsage { + ResponsesUsage { + input_tokens: usage.input_tokens.unwrap_or_default(), + output_tokens: usage.output_tokens.unwrap_or_default(), + total_tokens: usage.resolved_total_tokens().unwrap_or_default(), + } +} + +fn responses_usage_to_common(usage: &ResponsesUsage) -> Usage { + Usage { + input_tokens: Some(usage.input_tokens), + output_tokens: Some(usage.output_tokens), + total_tokens: Some(usage.total_tokens), + ..Default::default() + } +} + +fn response_message_output_id(response_id: &str, output_index: usize) -> String { + format!("{}_message_{}", response_id, output_index) +} + +fn response_function_call_output_id(response_id: &str, output_index: usize) -> String { + format!("{}_call_{}", response_id, output_index) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use pretty_assertions::assert_eq; + use serde_json::json; + + use crate::{ + gateway::types::{ + common::Usage, + openai::{ + ChatMessage, MessageContent, + responses::{ + ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent, + ResponsesContent, ResponsesContentPart, ResponsesInput, ResponsesInputItem, + ResponsesOutputContent, ResponsesOutputItem, ResponsesUsage, + }, + }, + }, + proxy::message_history::{ + InMemoryMessageHistoryStorage, MessageHistoryStorage, StoredMessageHistory, + StoredMessageHistoryStatus, + }, + }; + + fn user_message(text: &str) -> ChatMessage { + ChatMessage { + role: "user".into(), + content: Some(MessageContent::Text(text.into())), + name: None, + tool_calls: None, + tool_call_id: None, + } + } + + fn assistant_message(text: &str) -> ChatMessage { + ChatMessage { + role: "assistant".into(), + content: Some(MessageContent::Text(text.into())), + name: None, + tool_calls: None, + tool_call_id: None, + } + } + + fn message_text(message: &ChatMessage) -> Option<&str> { + match message.content.as_ref() { + Some(MessageContent::Text(text)) => Some(text.as_str()), + _ => None, + } + } + + fn text_request(text: &str) -> ResponsesApiRequest { + ResponsesApiRequest { + background: None, + context_management: None, + conversation: None, + include: None, + model: "gpt-4.1".into(), + input: ResponsesInput::Text(text.into()), + instructions: None, + max_output_tokens: None, + max_tool_calls: None, + temperature: None, + top_p: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + prompt: None, + prompt_cache_key: None, + prompt_cache_retention: None, + reasoning: None, + safety_identifier: None, + service_tier: None, + stream: None, + stream_options: None, + metadata: None, + text: None, + top_logprobs: None, + previous_response_id: None, + store: None, + truncation: None, + replay_messages: vec![], + } + } + + fn completed_response(id: &str, text: &str) -> ResponsesApiResponse { + ResponsesApiResponse { + id: id.into(), + object: "response".into(), + created_at: 123, + model: "gpt-4.1".into(), + output: vec![ResponsesOutputItem::Message { + id: format!("{}_message_0", id), + role: "assistant".into(), + content: vec![ResponsesOutputContent::OutputText { text: text.into() }], + status: "completed".into(), + }], + status: "completed".into(), + usage: ResponsesUsage { + input_tokens: 1, + output_tokens: 2, + total_tokens: 3, + }, + metadata: Some(json!({"response_meta": true})), + previous_response_id: None, + } + } + + #[tokio::test] + async fn load_previous_messages_restores_previous_snapshot_and_lifecycle_generates_response_id() + { + let store = Arc::new(InMemoryMessageHistoryStorage::default()); + store + .put(&StoredMessageHistory { + response_id: "resp_prev".into(), + cumulative_messages: vec![user_message("old"), assistant_message("done")], + model: "gpt-4.1".into(), + created_at: 100, + status: StoredMessageHistoryStatus::Completed, + ..Default::default() + }) + .await + .unwrap(); + + let mut request = text_request("next"); + request.previous_response_id = Some("resp_prev".into()); + request.metadata = Some(json!({"trace": "abc"})); + + let state = super::init_lifecycle(&request); + let previous_messages = + super::load_previous_messages(store.as_ref(), request.previous_response_id.as_deref()) + .await + .unwrap(); + let merged_input_messages = + super::build_merged_input_messages(&request, &previous_messages).unwrap(); + + assert_eq!(state.previous_response_id.as_deref(), Some("resp_prev")); + assert_eq!(previous_messages.len(), 2); + assert_eq!(merged_input_messages.len(), 3); + assert!(state.response_id.starts_with("aresp_")); + assert_ne!(state.response_id, "resp_prev"); + assert_eq!(state.metadata.get("trace"), Some(&json!("abc"))); + } + + #[tokio::test] + async fn finalize_complete_persists_combined_snapshot() { + let store = Arc::new(InMemoryMessageHistoryStorage::default()); + let mut request = text_request("hello"); + request.metadata = Some(json!({"trace": "abc"})); + + let mut state = super::init_lifecycle(&request); + state.merged_input_messages = super::build_merged_input_messages(&request, &[]).unwrap(); + let response_id = state.response_id.clone(); + + let stored = super::accumulate_complete( + &state, + &completed_response("up_resp_1", "world"), + Usage { + input_tokens: Some(1), + output_tokens: Some(2), + total_tokens: Some(3), + ..Default::default() + }, + ) + .unwrap(); + super::persist_if_enabled(store.as_ref(), &state, &stored) + .await + .unwrap(); + + assert_eq!(stored.response_id, response_id); + assert_eq!(stored.upstream_response_id.as_deref(), Some("up_resp_1")); + assert_eq!(stored.cumulative_messages.len(), 2); + assert_eq!(message_text(&stored.cumulative_messages[0]), Some("hello")); + assert_eq!(message_text(&stored.cumulative_messages[1]), Some("world")); + assert_eq!(stored.metadata.get("trace"), Some(&json!("abc"))); + assert_eq!(stored.metadata.get("response_meta"), Some(&json!(true))); + + let loaded = store + .get_by_response_id(&response_id) + .await + .unwrap() + .unwrap(); + assert_eq!(loaded.response_id, response_id); + assert_eq!( + loaded.usage.as_ref().and_then(|usage| usage.total_tokens), + Some(3) + ); + } + + #[tokio::test] + async fn finalize_stream_success_builds_snapshot_from_events() { + let store = Arc::new(InMemoryMessageHistoryStorage::default()); + let request = text_request("hello"); + let mut state = super::init_lifecycle(&request); + state.merged_input_messages = super::build_merged_input_messages(&request, &[]).unwrap(); + let response_id = state.response_id.clone(); + + super::accumulate_stream_event( + &mut state, + &ResponsesApiStreamEvent::ResponseCreated { + response: ResponsesApiResponse { + id: "up_resp_1".into(), + object: "response".into(), + created_at: 123, + model: "gpt-4.1".into(), + output: vec![], + status: "in_progress".into(), + usage: ResponsesUsage::default(), + metadata: Some(json!({"stream": true})), + previous_response_id: None, + }, + }, + ); + super::accumulate_stream_event( + &mut state, + &ResponsesApiStreamEvent::OutputItemAdded { + output_index: 0, + item: ResponsesOutputItem::Message { + id: "up_resp_1_message_0".into(), + role: "assistant".into(), + content: vec![], + status: "in_progress".into(), + }, + }, + ); + super::accumulate_stream_event( + &mut state, + &ResponsesApiStreamEvent::ContentPartAdded { + output_index: 0, + content_index: 0, + part: ResponsesOutputContent::OutputText { + text: String::new(), + }, + }, + ); + super::accumulate_stream_event( + &mut state, + &ResponsesApiStreamEvent::OutputTextDelta { + output_index: 0, + content_index: 0, + delta: "world".into(), + }, + ); + super::accumulate_stream_event( + &mut state, + &ResponsesApiStreamEvent::OutputItemDone { + output_index: 0, + item: ResponsesOutputItem::Message { + id: "up_resp_1_message_0".into(), + role: "assistant".into(), + content: vec![ResponsesOutputContent::OutputText { + text: "world".into(), + }], + status: "completed".into(), + }, + }, + ); + + let stored = super::accumulate_stream_success( + &state, + Some(&Usage { + input_tokens: Some(1), + output_tokens: Some(2), + total_tokens: Some(3), + ..Default::default() + }), + ) + .unwrap(); + super::persist_if_enabled(store.as_ref(), &state, &stored) + .await + .unwrap(); + + assert_eq!(stored.response_id, response_id); + assert_eq!(stored.upstream_response_id.as_deref(), Some("up_resp_1")); + assert_eq!(stored.cumulative_messages.len(), 2); + assert_eq!(message_text(&stored.cumulative_messages[1]), Some("world")); + assert_eq!(stored.metadata.get("stream"), Some(&json!(true))); + + let loaded = store + .get_by_response_id(&response_id) + .await + .unwrap() + .unwrap(); + assert_eq!(loaded.response_id, response_id); + assert_eq!(loaded.status, StoredMessageHistoryStatus::Completed); + } + + #[tokio::test] + async fn load_previous_messages_returns_validation_when_previous_response_is_missing() { + let store = Arc::new(InMemoryMessageHistoryStorage::default()); + let mut request = text_request("hello"); + request.previous_response_id = Some("resp_missing".into()); + + let error = + super::load_previous_messages(store.as_ref(), request.previous_response_id.as_deref()) + .await + .unwrap_err(); + assert_eq!( + error.to_string(), + "validation: previous_response_not_found: resp_missing" + ); + } + + #[test] + fn response_output_to_chat_messages_groups_function_calls_under_assistant() { + let messages = super::response_output_to_chat_messages(&[ + ResponsesOutputItem::Message { + id: "resp_message_0".into(), + role: "assistant".into(), + content: vec![ResponsesOutputContent::OutputText { + text: "hello".into(), + }], + status: "completed".into(), + }, + ResponsesOutputItem::FunctionCall { + id: "call_1".into(), + call_id: "call_1".into(), + name: "lookup".into(), + arguments: "{\"city\":\"Paris\"}".into(), + status: "completed".into(), + }, + ]); + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, "assistant"); + assert_eq!(messages[0].tool_calls.as_ref().map(Vec::len), Some(1)); + } + + #[test] + fn request_input_messages_keep_tool_outputs() { + let request = ResponsesApiRequest { + input: ResponsesInput::Items(vec![ + ResponsesInputItem::Message { + role: "user".into(), + content: ResponsesContent::Text("hello".into()), + }, + ResponsesInputItem::FunctionCallOutput { + call_id: "call_1".into(), + output: "42".into(), + }, + ]), + ..text_request("ignored") + }; + + let messages = super::request_input_messages(&request).unwrap(); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, "user"); + assert_eq!(messages[1].role, "tool"); + assert_eq!(messages[1].tool_call_id.as_deref(), Some("call_1")); + } + + #[test] + fn request_input_messages_skip_non_persistable_input_image_file_references() { + let request = ResponsesApiRequest { + input: ResponsesInput::Items(vec![ResponsesInputItem::Message { + role: "user".into(), + content: ResponsesContent::Parts(vec![ResponsesContentPart::InputImage { + image_url: None, + file_id: Some("file_123".into()), + detail: Some("high".into()), + }]), + }]), + ..text_request("ignored") + }; + + let messages = super::request_input_messages(&request).unwrap(); + assert!(messages.is_empty()); + } +} diff --git a/src/proxy/handlers/responses/types.rs b/src/proxy/handlers/responses/types.rs index 7aa5449..4fe1f15 100644 --- a/src/proxy/handlers/responses/types.rs +++ b/src/proxy/handlers/responses/types.rs @@ -1,14 +1,15 @@ -use axum::{ - Json, - response::{IntoResponse, Response}, -}; -use http::StatusCode; +use axum::response::{IntoResponse, Response}; use thiserror::Error; use tokio::time::error::Elapsed; use crate::{ gateway::error::GatewayError, - proxy::hooks::{authorization::AuthorizationError, rate_limit::RateLimitError}, + proxy::{ + handlers::openai_error::{ + gateway_error_response, missing_model_response, timeout_response, + }, + hooks::{authorization::AuthorizationError, rate_limit::RateLimitError}, + }, }; #[derive(Debug, Error)] @@ -31,65 +32,62 @@ impl IntoResponse for ResponsesError { ResponsesError::AuthorizationError(err) => err.into_response(), ResponsesError::RateLimitError(RateLimitError::Raw(resp)) => resp, ResponsesError::GatewayError(err) => { - let status = err.status_code(); - let (message, error_type, code) = match err { - GatewayError::Provider { .. } - | GatewayError::Http(_) - | GatewayError::Stream(_) => ( - "Provider error".to_string(), - "server_error", - "provider_error", - ), - GatewayError::Internal(_) => ( - "Gateway internal error".to_string(), - "server_error", - "internal_error", - ), - _ => ( - err.to_string(), - if status.is_client_error() { - "invalid_request_error" - } else { - "server_error" - }, - "gateway_error", - ), - }; + gateway_error_response("Responses", &err, err.status_code()) + } + ResponsesError::Timeout(_) => timeout_response("Responses"), + ResponsesError::MissingModelInContext => missing_model_response("Responses"), + } + } +} + +#[cfg(test)] +mod tests { + use axum::response::IntoResponse; + use http::StatusCode; + use http_body_util::BodyExt; + use pretty_assertions::assert_eq; + use serde_json::{Value, json}; - ( - status, - Json(serde_json::json!({ + use super::ResponsesError; + use crate::gateway::error::GatewayError; + + #[tokio::test] + async fn provider_errors_keep_provider_error_shape_but_surface_message_and_details() { + let response = ResponsesError::GatewayError(GatewayError::Provider { + status: StatusCode::BAD_REQUEST, + body: json!({ + "error": { + "message": "unknown model", + "type": "invalid_request_error", + "code": "model_not_found" + } + }), + provider: "openai".into(), + retryable: false, + }) + .into_response(); + + let status = response.status(); + let body = response.into_body().collect().await.unwrap().to_bytes(); + let payload: Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!( + payload, + json!({ + "error": { + "message": "unknown model", + "type": "server_error", + "code": "provider_error", + "details": { "error": { - "message": message, - "type": error_type, - "code": code + "message": "unknown model", + "type": "invalid_request_error", + "code": "model_not_found" } - })), - ) - .into_response() - } - ResponsesError::Timeout(_) => ( - StatusCode::GATEWAY_TIMEOUT, - Json(serde_json::json!({ - "error": { - "message": "Provider request timed out", - "type": "server_error", - "code": "request_timeout" - } - })), - ) - .into_response(), - ResponsesError::MissingModelInContext => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": { - "message": "model missing in request context", - "type": "server_error", - "code": "internal_error" } - })), - ) - .into_response(), - } + } + }) + ); } } diff --git a/src/proxy/message_history.rs b/src/proxy/message_history.rs new file mode 100644 index 0000000..d09a161 --- /dev/null +++ b/src/proxy/message_history.rs @@ -0,0 +1,166 @@ +use std::collections::HashMap; + +use async_trait::async_trait; +use serde_json::Value; +use tokio::sync::RwLock; + +use crate::gateway::{ + error::Result, + types::{common::Usage, openai::ChatMessage}, +}; + +#[async_trait] +pub(crate) trait MessageHistoryStorage: Send + Sync + 'static { + async fn get_by_response_id(&self, response_id: &str) -> Result>; + + async fn put(&self, history: &StoredMessageHistory) -> Result<()>; +} + +#[allow(dead_code)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub(crate) enum StoredMessageHistoryStatus { + #[default] + Completed, + Failed, +} + +#[allow(dead_code)] +#[derive(Debug, Clone, Default)] +pub(crate) struct StoredMessageHistory { + pub response_id: String, + pub previous_response_id: Option, + pub upstream_response_id: Option, + pub cumulative_messages: Vec, + pub model: String, + pub created_at: u64, + pub finished_at: Option, + pub usage: Option, + pub status: StoredMessageHistoryStatus, + pub metadata: HashMap, +} + +/// In-memory message history storage intended for tests and short-lived local runs. +/// +/// This implementation keeps every stored history entry for the lifetime of the +/// process. It has no eviction policy, no size limits, and no persistence, so +/// sustained use can grow without bound and eventually OOM the process. +#[derive(Debug, Default)] +pub(crate) struct InMemoryMessageHistoryStorage { + histories: RwLock>, +} + +#[async_trait] +impl MessageHistoryStorage for InMemoryMessageHistoryStorage { + async fn get_by_response_id(&self, response_id: &str) -> Result> { + Ok(self.histories.read().await.get(response_id).cloned()) + } + + async fn put(&self, history: &StoredMessageHistory) -> Result<()> { + self.histories + .write() + .await + .insert(history.response_id.clone(), history.clone()); + Ok(()) + } +} + +#[cfg(test)] +impl InMemoryMessageHistoryStorage { + async fn delete(&self, response_id: &str) -> Result<()> { + self.histories.write().await.remove(response_id); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::{ + InMemoryMessageHistoryStorage, MessageHistoryStorage, StoredMessageHistory, + StoredMessageHistoryStatus, + }; + use crate::gateway::types::{ + common::Usage, + openai::{ChatMessage, MessageContent}, + }; + + fn sample_message(text: &str) -> ChatMessage { + ChatMessage { + role: "user".into(), + content: Some(MessageContent::Text(text.into())), + name: None, + tool_calls: None, + tool_call_id: None, + } + } + + #[tokio::test] + async fn in_memory_storage_round_trips_message_history() { + let storage = InMemoryMessageHistoryStorage::default(); + let history = StoredMessageHistory { + response_id: "resp_1".into(), + previous_response_id: Some("resp_0".into()), + upstream_response_id: Some("chatcmpl-upstream-1".into()), + cumulative_messages: vec![sample_message("hello")], + model: "gpt-test".into(), + created_at: 10, + finished_at: Some(11), + usage: Some(Usage { + input_tokens: Some(1), + output_tokens: Some(2), + total_tokens: Some(3), + ..Default::default() + }), + status: StoredMessageHistoryStatus::Completed, + metadata: HashMap::from([("trace".into(), json!("abc"))]), + }; + + storage.put(&history).await.unwrap(); + + let loaded = storage.get_by_response_id("resp_1").await.unwrap().unwrap(); + assert_eq!(loaded.response_id, "resp_1"); + assert_eq!(loaded.previous_response_id.as_deref(), Some("resp_0")); + assert_eq!( + loaded.upstream_response_id.as_deref(), + Some("chatcmpl-upstream-1") + ); + assert_eq!(loaded.cumulative_messages.len(), 1); + assert_eq!(loaded.model, "gpt-test"); + assert_eq!(loaded.metadata.get("trace"), Some(&json!("abc"))); + } + + #[tokio::test] + async fn delete_removes_entry() { + let storage = InMemoryMessageHistoryStorage::default(); + let history = StoredMessageHistory { + response_id: "resp_1".into(), + cumulative_messages: vec![sample_message("hello")], + model: "gpt-test".into(), + created_at: 10, + status: StoredMessageHistoryStatus::Completed, + ..Default::default() + }; + + storage.put(&history).await.unwrap(); + assert!( + storage + .get_by_response_id("resp_1") + .await + .unwrap() + .is_some() + ); + + storage.delete("resp_1").await.unwrap(); + assert!( + storage + .get_by_response_id("resp_1") + .await + .unwrap() + .is_none() + ); + } +} diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index c461b92..017cb29 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,5 +1,6 @@ mod handlers; mod hooks; +pub(crate) mod message_history; mod middlewares; mod provider; mod utils; @@ -13,6 +14,7 @@ use axum::{ routing::{get, post}, }; +use self::message_history::MessageHistoryStorage; use crate::{ config::{Config, entities::ResourceRegistry}, gateway::Gateway, @@ -24,6 +26,7 @@ pub struct AppState { config: Arc, resources: Arc, gateway: Arc, + message_history_storage: Arc, } impl AppState { @@ -31,11 +34,13 @@ impl AppState { config: Arc, resources: Arc, gateway: Arc, + message_history_storage: Arc, ) -> Self { Self { config, resources, gateway, + message_history_storage, } } @@ -46,6 +51,10 @@ impl AppState { pub fn gateway(&self) -> Arc { self.gateway.clone() } + + pub fn message_history_storage(&self) -> Arc { + self.message_history_storage.clone() + } } pub fn create_router(state: AppState) -> Router { diff --git a/tests/package.json b/tests/package.json index a666bae..39024d6 100644 --- a/tests/package.json +++ b/tests/package.json @@ -1,13 +1,14 @@ { "name": "e2e", "version": "0.0.0", + "private": true, "type": "module", "scripts": { "lint": "eslint", "test": "vitest run", "test:dev": "vitest" }, - "packageManager": "pnpm@10.32.1", + "packageManager": "pnpm@11.0.8+sha512.4c4097e1dd2d42372c4e7fa5a791ff28fc75a484c7ac192e64b1df0fdef17594ba982f9b4fed9adfb3c757846f565b799b2763fb3733d1de1bcb82cf46684912", "devDependencies": { "@anthropic-ai/sdk": "^0.88.0", "@eslint/js": "^10.0.1",