diff --git a/src/config/entities/providers-schema.json b/src/config/entities/providers-schema.json index b70a3bd..05fc583 100644 --- a/src/config/entities/providers-schema.json +++ b/src/config/entities/providers-schema.json @@ -19,7 +19,8 @@ "moonshotai", "moonshotai-cn", "openai", - "openrouter" + "openrouter", + "zhipuai" ] }, "config": { "type": "object" } @@ -65,7 +66,8 @@ "moonshotai", "moonshotai-cn", "openai", - "openrouter" + "openrouter", + "zhipuai" ] } }, diff --git a/src/config/entities/providers.rs b/src/config/entities/providers.rs index e2711d0..9649e98 100644 --- a/src/config/entities/providers.rs +++ b/src/config/entities/providers.rs @@ -50,6 +50,8 @@ pub enum ProviderConfig { OpenAI(configs::OpenAIProviderConfig), #[serde(rename = "openrouter")] OpenRouter(configs::OpenRouterProviderConfig), + #[serde(rename = "zhipuai")] + ZhipuAi(configs::ZhipuAiProviderConfig), } impl ProviderConfig { @@ -69,6 +71,7 @@ impl ProviderConfig { Self::MoonshotAiCn(_) => identifiers::MOONSHOT_AI_CN, Self::OpenAI(_) => identifiers::OPENAI, Self::OpenRouter(_) => identifiers::OPENROUTER, + Self::ZhipuAi(_) => identifiers::ZHIPUAI, } } } @@ -201,6 +204,11 @@ mod tests { "type": "moonshotai-cn", "config": { "api_key": "test_key" } }), true, None)] + #[case::zhipuai_ok(json!({ + "name": "zhipu-primary", + "type": "zhipuai", + "config": { "api_key": "test_key" } + }), true, None)] #[case::missing_type(json!({ "name": "openai-primary", "config": { "api_key": "test_key" } diff --git a/src/gateway/providers/mod.rs b/src/gateway/providers/mod.rs index 6ee6b28..e89c02c 100644 --- a/src/gateway/providers/mod.rs +++ b/src/gateway/providers/mod.rs @@ -12,6 +12,7 @@ pub mod moonshot; pub mod openai; pub mod openrouter; pub mod xai; +pub mod zhipuai; pub use anthropic::AnthropicDef; pub use azure::AzureDef; @@ -26,11 +27,12 @@ pub use moonshot::{MoonshotAi, MoonshotAiCn}; pub use openai::OpenAIDef; pub use openrouter::OpenRouter; pub use xai::Xai; +pub use zhipuai::ZhipuAi; pub mod identifiers { use super::{ anthropic, azure, bedrock, cohere, deepseek, fireworks, gemini, groq, mistral, moonshot, - openai, openrouter, xai, + openai, openrouter, xai, zhipuai, }; pub const ANTHROPIC: &str = anthropic::IDENTIFIER; @@ -47,6 +49,7 @@ pub mod identifiers { pub const OPENAI: &str = openai::IDENTIFIER; pub const OPENROUTER: &str = openrouter::IDENTIFIER; pub const XAI: &str = xai::IDENTIFIER; + pub const ZHIPUAI: &str = zhipuai::IDENTIFIER; } pub mod configs { @@ -64,6 +67,7 @@ pub mod configs { openai::OpenAIProviderConfig, openrouter::OpenRouterProviderConfig, xai::XaiProviderConfig, + zhipuai::ZhipuAiProviderConfig, }; } @@ -84,7 +88,8 @@ pub fn default_provider_registry() -> Result { .register(MoonshotAiCn)? .register(OpenAIDef)? .register(OpenRouter)? - .register(Xai)?; + .register(Xai)? + .register(ZhipuAi)?; Ok(builder.build()) } @@ -115,6 +120,7 @@ mod tests { assert_eq!(registry.get("deepseek").unwrap().name(), "deepseek"); assert_eq!(registry.get("openrouter").unwrap().name(), "openrouter"); assert_eq!(registry.get("xai").unwrap().name(), "xai"); + assert_eq!(registry.get("zhipuai").unwrap().name(), "zhipuai"); assert!(registry.get("missing").is_none()); } } diff --git a/src/gateway/providers/zhipuai.rs b/src/gateway/providers/zhipuai.rs new file mode 100644 index 0000000..b3ee484 --- /dev/null +++ b/src/gateway/providers/zhipuai.rs @@ -0,0 +1,259 @@ +use std::borrow::Cow; + +use http::{HeaderMap, HeaderValue, header::AUTHORIZATION}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::gateway::{ + error::{GatewayError, Result}, + provider_instance::ProviderAuth, + traits::{ChatTransform, EmbedTransform, ProviderCapabilities, ProviderMeta}, + types::{ + embed::{EmbedRequestBody, EmbeddingRequest}, + openai::ChatCompletionRequest, + }, +}; + +/// ZhipuAI currently exposes an OpenAI-compatible base URL that already ends in +/// /api/paas/v4, so chat and embeddings both override their endpoint suffixes. +/// Docs: +/// - https://docs.bigmodel.cn/cn/guide/develop/openai/introduction.md +/// - https://docs.bigmodel.cn/api-reference/%E6%A8%A1%E5%9E%8B-api/%E6%96%87%E6%9C%AC%E5%B5%8C%E5%85%A5.md +pub const IDENTIFIER: &str = "zhipuai"; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct ZhipuAiProviderConfig { + pub api_key: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub api_base: Option, +} + +pub struct ZhipuAi; + +impl ProviderMeta for ZhipuAi { + fn name(&self) -> &'static str { + IDENTIFIER + } + + fn default_base_url(&self) -> &'static str { + "https://open.bigmodel.cn/api/paas/v4" + } + + fn chat_endpoint_path(&self, _model: &str) -> Cow<'static, str> { + Cow::Borrowed("/chat/completions") + } + + fn build_auth_headers(&self, auth: &ProviderAuth) -> Result { + let mut headers = HeaderMap::new(); + let value = HeaderValue::from_str(&format!("Bearer {}", auth.api_key_for(self.name())?)) + .map_err(|error| GatewayError::Validation(error.to_string()))?; + headers.insert(AUTHORIZATION, value); + Ok(headers) + } +} + +impl ChatTransform for ZhipuAi { + fn transform_request(&self, request: &ChatCompletionRequest) -> Result { + let body = serde_json::to_value(request) + .map_err(|error| GatewayError::Transform(error.to_string()))?; + + validate_temperature(&body)?; + + Ok(body) + } +} + +impl EmbedTransform for ZhipuAi { + fn embeddings_endpoint_path(&self, _model: &str) -> Cow<'static, str> { + Cow::Borrowed("/embeddings") + } + + fn transform_embeddings_request(&self, request: &EmbeddingRequest) -> Result { + let mut body = serde_json::to_value(request) + .map_err(|error| GatewayError::Transform(error.to_string()))?; + + if let Value::Object(map) = &mut body { + validate_dimensions(map)?; + + // ZhipuAI embeddings document model, input, and dimensions, but not + // OpenAI's encoding_format or user fields. + map.remove("encoding_format"); + map.remove("user"); + } + + Ok(EmbedRequestBody::Json(body)) + } +} + +impl ProviderCapabilities for ZhipuAi { + fn as_embed_transform(&self) -> Option<&dyn EmbedTransform> { + Some(self) + } +} + +fn validate_temperature(body: &Value) -> Result<()> { + let Value::Object(map) = body else { + return Ok(()); + }; + + let Some(temperature) = map.get("temperature").and_then(Value::as_f64) else { + return Ok(()); + }; + + if temperature <= 0.0 || temperature > 1.0 { + return Err(GatewayError::Validation(format!( + "zhipuai requires temperature to be within (0, 1], got {temperature}" + ))); + } + + Ok(()) +} + +fn validate_dimensions(map: &serde_json::Map) -> Result<()> { + let Some(model) = map.get("model").and_then(Value::as_str) else { + return Ok(()); + }; + let Some(dimensions) = map.get("dimensions").and_then(Value::as_u64) else { + return Ok(()); + }; + + match model { + "embedding-3" if !matches!(dimensions, 256 | 512 | 1024 | 2048) => { + return Err(GatewayError::Validation(format!( + "zhipuai embedding-3 only supports dimensions 256, 512, 1024, or 2048, got {dimensions}" + ))); + } + "embedding-2" if dimensions != 1024 => { + return Err(GatewayError::Validation(format!( + "zhipuai embedding-2 uses fixed dimension 1024, got {dimensions}" + ))); + } + _ => {} + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::ZhipuAi; + use crate::gateway::{ + provider_instance::ProviderAuth, + traits::{ChatTransform, EmbedTransform, ProviderCapabilities, ProviderMeta}, + types::{ + embed::{EmbedRequestBody, EmbeddingRequest}, + openai::ChatCompletionRequest, + }, + }; + + #[test] + fn provider_metadata_and_urls_are_correct() { + let provider = ZhipuAi; + let headers = provider + .build_auth_headers(&ProviderAuth::ApiKey("zhipu-key".into())) + .unwrap(); + + assert_eq!(provider.name(), "zhipuai"); + assert_eq!( + provider.default_base_url(), + "https://open.bigmodel.cn/api/paas/v4" + ); + assert_eq!(headers["authorization"], "Bearer zhipu-key"); + assert_eq!( + provider.build_url(provider.default_base_url(), "glm-5.1"), + "https://open.bigmodel.cn/api/paas/v4/chat/completions" + ); + assert_eq!( + provider.embeddings_endpoint_path("embedding-3"), + "/embeddings" + ); + assert!(provider.as_embed_transform().is_some()); + } + + #[test] + fn transform_request_rejects_zero_temperature() { + let provider = ZhipuAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "glm-5.1", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.0, + "thinking": {"type": "enabled"}, + "tool_stream": true + })) + .unwrap(); + + let error = provider.transform_request(&request).unwrap_err(); + + assert_matches!( + error, + crate::gateway::error::GatewayError::Validation(message) + if message.contains("temperature") && message.contains("(0, 1]") + ); + } + + #[test] + fn transform_request_preserves_zhipu_extensions() { + let provider = ZhipuAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "glm-4.7", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.8, + "thinking": {"type": "enabled"}, + "tool_stream": true + })) + .unwrap(); + + let transformed = provider.transform_request(&request).unwrap(); + + assert_eq!(transformed["temperature"], 0.8); + assert_eq!(transformed["thinking"]["type"], "enabled"); + assert_eq!(transformed["tool_stream"], true); + } + + #[test] + fn transform_embeddings_request_strips_openai_only_fields() { + let provider = ZhipuAi; + let request: EmbeddingRequest = serde_json::from_value(json!({ + "model": "embedding-3", + "input": ["hello"], + "dimensions": 512, + "encoding_format": "float", + "user": "user-123" + })) + .unwrap(); + + let body = provider.transform_embeddings_request(&request).unwrap(); + + match body { + EmbedRequestBody::Json(value) => { + assert_eq!(value["dimensions"], 512); + assert_eq!(value.get("encoding_format"), None); + assert_eq!(value.get("user"), None); + } + } + } + + #[test] + fn transform_embeddings_request_rejects_invalid_embedding_dimensions() { + let provider = ZhipuAi; + let request: EmbeddingRequest = serde_json::from_value(json!({ + "model": "embedding-3", + "input": ["hello"], + "dimensions": 768 + })) + .unwrap(); + + let error = provider.transform_embeddings_request(&request).unwrap_err(); + + assert_matches!( + error, + crate::gateway::error::GatewayError::Validation(message) + if message.contains("embedding-3") && message.contains("256") + ); + } +} diff --git a/src/proxy/provider.rs b/src/proxy/provider.rs index 4623b6b..b73ea58 100644 --- a/src/proxy/provider.rs +++ b/src/proxy/provider.rs @@ -85,6 +85,10 @@ fn provider_auth_and_base_url(config: &ProviderConfig) -> Result<(ProviderAuth, ProviderAuth::ApiKey(config.api_key.clone()), parse_base_url(config.api_base.as_deref())?, ), + ProviderConfig::ZhipuAi(config) => ( + ProviderAuth::ApiKey(config.api_key.clone()), + parse_base_url(config.api_base.as_deref())?, + ), }; Ok((auth, base_url_override)) @@ -176,7 +180,7 @@ mod tests { AzureProviderConfig, BedrockProviderConfig, CohereProviderConfig, FireworksAiProviderConfig, GroqProviderConfig, MistralProviderConfig, MoonshotAiCnProviderConfig, MoonshotAiProviderConfig, OpenRouterProviderConfig, - XaiProviderConfig, + XaiProviderConfig, ZhipuAiProviderConfig, }, }; @@ -350,6 +354,22 @@ mod tests { ); } + #[test] + fn provider_auth_and_base_url_returns_zhipuai_api_key_and_optional_base_url() { + let config = ProviderConfig::ZhipuAi(ZhipuAiProviderConfig { + api_key: "zhipu-key".into(), + api_base: Some("https://open.bigmodel.cn/api/paas/v4".into()), + }); + + let (auth, base_url_override) = provider_auth_and_base_url(&config).unwrap(); + + assert_eq!(auth.api_key_for("zhipuai").unwrap(), "zhipu-key"); + assert_eq!( + base_url_override.as_ref().map(Url::as_str), + Some("https://open.bigmodel.cn/api/paas/v4") + ); + } + #[test] fn provider_auth_and_base_url_returns_bedrock_static_credentials() { let config = ProviderConfig::Bedrock(BedrockProviderConfig { diff --git a/ui/src/i18n/locales/en.json b/ui/src/i18n/locales/en.json index f338721..93b849b 100644 --- a/ui/src/i18n/locales/en.json +++ b/ui/src/i18n/locales/en.json @@ -156,6 +156,7 @@ "mistral": "Mistral", "moonshotai": "Moonshot AI", "moonshotai-cn": "Moonshot AI (CN)", + "zhipuai": "ZhipuAI", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", @@ -231,6 +232,7 @@ "xai": "xAI", "moonshotai": "Moonshot AI", "moonshotai-cn": "Moonshot AI (CN)", + "zhipuai": "ZhipuAI", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", diff --git a/ui/src/i18n/locales/zh-CN.json b/ui/src/i18n/locales/zh-CN.json index 1726d87..23631a0 100644 --- a/ui/src/i18n/locales/zh-CN.json +++ b/ui/src/i18n/locales/zh-CN.json @@ -156,6 +156,7 @@ "mistral": "Mistral", "moonshotai": "Moonshot AI", "moonshotai-cn": "Moonshot AI (CN)", + "zhipuai": "智谱", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", @@ -231,6 +232,7 @@ "xai": "xAI", "moonshotai": "Moonshot AI", "moonshotai-cn": "Moonshot AI (CN)", + "zhipuai": "ZhipuAI", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", diff --git a/ui/src/lib/api/types.ts b/ui/src/lib/api/types.ts index 8006264..2c6b9d8 100644 --- a/ui/src/lib/api/types.ts +++ b/ui/src/lib/api/types.ts @@ -47,6 +47,7 @@ export const PROVIDER_TYPE_VARIANTS = [ 'mistral', 'moonshotai', 'moonshotai-cn', + 'zhipuai', 'azure', 'anthropic', 'gemini', @@ -141,6 +142,11 @@ export type Provider = type: 'moonshotai-cn'; config: ApiBaseProviderConfig; } + | { + name: string; + type: 'zhipuai'; + config: ApiBaseProviderConfig; + } | { name: string; type: 'bedrock';