diff --git a/src/config/entities/providers-schema.json b/src/config/entities/providers-schema.json index 04ed1a0..b70a3bd 100644 --- a/src/config/entities/providers-schema.json +++ b/src/config/entities/providers-schema.json @@ -16,6 +16,8 @@ "groq", "xai", "mistral", + "moonshotai", + "moonshotai-cn", "openai", "openrouter" ] @@ -60,6 +62,8 @@ "groq", "xai", "mistral", + "moonshotai", + "moonshotai-cn", "openai", "openrouter" ] diff --git a/src/config/entities/providers.rs b/src/config/entities/providers.rs index 86a82db..e2711d0 100644 --- a/src/config/entities/providers.rs +++ b/src/config/entities/providers.rs @@ -42,6 +42,10 @@ pub enum ProviderConfig { Xai(configs::XaiProviderConfig), #[serde(rename = "mistral")] Mistral(configs::MistralProviderConfig), + #[serde(rename = "moonshotai")] + MoonshotAi(configs::MoonshotAiProviderConfig), + #[serde(rename = "moonshotai-cn")] + MoonshotAiCn(configs::MoonshotAiCnProviderConfig), #[serde(rename = "openai")] OpenAI(configs::OpenAIProviderConfig), #[serde(rename = "openrouter")] @@ -61,6 +65,8 @@ impl ProviderConfig { Self::Groq(_) => identifiers::GROQ, Self::Xai(_) => identifiers::XAI, Self::Mistral(_) => identifiers::MISTRAL, + Self::MoonshotAi(_) => identifiers::MOONSHOT_AI, + Self::MoonshotAiCn(_) => identifiers::MOONSHOT_AI_CN, Self::OpenAI(_) => identifiers::OPENAI, Self::OpenRouter(_) => identifiers::OPENROUTER, } @@ -185,6 +191,16 @@ mod tests { "type": "mistral", "config": { "api_key": "test_key" } }), true, None)] + #[case::moonshotai_ok(json!({ + "name": "moonshot-primary", + "type": "moonshotai", + "config": { "api_key": "test_key" } + }), true, None)] + #[case::moonshotai_cn_ok(json!({ + "name": "moonshot-cn-primary", + "type": "moonshotai-cn", + "config": { "api_key": "test_key" } + }), true, None)] #[case::missing_type(json!({ "name": "openai-primary", "config": { "api_key": "test_key" } diff --git a/src/gateway/providers/mod.rs b/src/gateway/providers/mod.rs index c0bdde6..6ee6b28 100644 --- a/src/gateway/providers/mod.rs +++ b/src/gateway/providers/mod.rs @@ -8,6 +8,7 @@ pub mod gemini; pub mod groq; pub mod macros; pub mod mistral; +pub mod moonshot; pub mod openai; pub mod openrouter; pub mod xai; @@ -21,14 +22,15 @@ pub use fireworks::FireworksAi; pub use gemini::GoogleDef; pub use groq::Groq; pub use mistral::Mistral; +pub use moonshot::{MoonshotAi, MoonshotAiCn}; pub use openai::OpenAIDef; pub use openrouter::OpenRouter; pub use xai::Xai; pub mod identifiers { use super::{ - anthropic, azure, bedrock, cohere, deepseek, fireworks, gemini, groq, mistral, openai, - openrouter, xai, + anthropic, azure, bedrock, cohere, deepseek, fireworks, gemini, groq, mistral, moonshot, + openai, openrouter, xai, }; pub const ANTHROPIC: &str = anthropic::IDENTIFIER; @@ -40,6 +42,8 @@ pub mod identifiers { pub const GEMINI: &str = gemini::IDENTIFIER; pub const GROQ: &str = groq::IDENTIFIER; pub const MISTRAL: &str = mistral::IDENTIFIER; + pub const MOONSHOT_AI: &str = moonshot::IDENTIFIER; + pub const MOONSHOT_AI_CN: &str = moonshot::CN_IDENTIFIER; pub const OPENAI: &str = openai::IDENTIFIER; pub const OPENROUTER: &str = openrouter::IDENTIFIER; pub const XAI: &str = xai::IDENTIFIER; @@ -47,11 +51,19 @@ pub mod identifiers { pub mod configs { pub use super::{ - anthropic::AnthropicProviderConfig, azure::AzureProviderConfig, - bedrock::BedrockProviderConfig, cohere::CohereProviderConfig, - deepseek::DeepSeekProviderConfig, fireworks::FireworksAiProviderConfig, - gemini::GeminiProviderConfig, groq::GroqProviderConfig, mistral::MistralProviderConfig, - openai::OpenAIProviderConfig, openrouter::OpenRouterProviderConfig, xai::XaiProviderConfig, + anthropic::AnthropicProviderConfig, + azure::AzureProviderConfig, + bedrock::BedrockProviderConfig, + cohere::CohereProviderConfig, + deepseek::DeepSeekProviderConfig, + fireworks::FireworksAiProviderConfig, + gemini::GeminiProviderConfig, + groq::GroqProviderConfig, + mistral::MistralProviderConfig, + moonshot::{MoonshotAiCnProviderConfig, MoonshotAiProviderConfig}, + openai::OpenAIProviderConfig, + openrouter::OpenRouterProviderConfig, + xai::XaiProviderConfig, }; } @@ -68,6 +80,8 @@ pub fn default_provider_registry() -> Result { .register(GoogleDef)? .register(Groq)? .register(Mistral)? + .register(MoonshotAi)? + .register(MoonshotAiCn)? .register(OpenAIDef)? .register(OpenRouter)? .register(Xai)?; @@ -93,6 +107,11 @@ mod tests { assert_eq!(registry.get("gemini").unwrap().name(), "gemini"); assert_eq!(registry.get("groq").unwrap().name(), "groq"); assert_eq!(registry.get("mistral").unwrap().name(), "mistral"); + assert_eq!(registry.get("moonshotai").unwrap().name(), "moonshotai"); + assert_eq!( + registry.get("moonshotai-cn").unwrap().name(), + "moonshotai-cn" + ); assert_eq!(registry.get("deepseek").unwrap().name(), "deepseek"); assert_eq!(registry.get("openrouter").unwrap().name(), "openrouter"); assert_eq!(registry.get("xai").unwrap().name(), "xai"); diff --git a/src/gateway/providers/moonshot.rs b/src/gateway/providers/moonshot.rs new file mode 100644 index 0000000..e95882c --- /dev/null +++ b/src/gateway/providers/moonshot.rs @@ -0,0 +1,403 @@ +//! Moonshot AI currently publishes a single official OpenAI-compatible endpoint, +//! so both moonshotai and moonshotai-cn share the same wire behavior. +//! +//! Docs: +//! - https://platform.kimi.com/docs/api/overview.md +//! - https://platform.kimi.com/docs/api/chat.md +//! - https://platform.kimi.com/docs/api/models-overview.md +//! - https://platform.kimi.com/docs/guide/migrating-from-openai-to-kimi.md + +use http::{HeaderMap, HeaderValue, header::AUTHORIZATION}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::gateway::{ + error::{GatewayError, Result}, + provider_instance::ProviderAuth, + traits::{ChatTransform, ProviderCapabilities, ProviderMeta}, + types::openai::ChatCompletionRequest, +}; + +pub const IDENTIFIER: &str = "moonshotai"; +pub const CN_IDENTIFIER: &str = "moonshotai-cn"; + +const DEFAULT_BASE_URL: &str = "https://api.moonshot.cn/v1"; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct MoonshotAiProviderConfig { + pub api_key: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub api_base: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct MoonshotAiCnProviderConfig { + pub api_key: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub api_base: Option, +} + +pub struct MoonshotAi; +pub struct MoonshotAiCn; + +impl ProviderMeta for MoonshotAi { + fn name(&self) -> &'static str { + IDENTIFIER + } + + fn default_base_url(&self) -> &'static str { + DEFAULT_BASE_URL + } + + fn build_auth_headers(&self, auth: &ProviderAuth) -> Result { + build_auth_headers(self.name(), auth) + } +} + +impl ProviderMeta for MoonshotAiCn { + fn name(&self) -> &'static str { + CN_IDENTIFIER + } + + fn default_base_url(&self) -> &'static str { + DEFAULT_BASE_URL + } + + fn build_auth_headers(&self, auth: &ProviderAuth) -> Result { + build_auth_headers(self.name(), auth) + } +} + +impl ChatTransform for MoonshotAi { + fn transform_request(&self, request: &ChatCompletionRequest) -> Result { + transform_request(request) + } +} + +impl ChatTransform for MoonshotAiCn { + fn transform_request(&self, request: &ChatCompletionRequest) -> Result { + transform_request(request) + } +} + +impl ProviderCapabilities for MoonshotAi {} + +impl ProviderCapabilities for MoonshotAiCn {} + +fn build_auth_headers(identifier: &str, auth: &ProviderAuth) -> Result { + let mut headers = HeaderMap::new(); + let value = HeaderValue::from_str(&format!("Bearer {}", auth.api_key_for(identifier)?)) + .map_err(|error| GatewayError::Validation(error.to_string()))?; + headers.insert(AUTHORIZATION, value); + Ok(headers) +} + +fn transform_request(request: &ChatCompletionRequest) -> Result { + let mut body = serde_json::to_value(request) + .map_err(|error| GatewayError::Transform(error.to_string()))?; + + let Value::Object(map) = &mut body else { + return Ok(body); + }; + + convert_legacy_functions(map)?; + convert_legacy_function_call(map)?; + validate_tool_choice(map)?; + + let model = map + .get("model") + .and_then(Value::as_str) + .ok_or_else(|| GatewayError::Validation("moonshot providers require a string model".into()))? + .to_string(); + + apply_model_specific_quirks(map, model.as_str()); + validate_generic_constraints(map)?; + + Ok(body) +} + +fn convert_legacy_functions(map: &mut serde_json::Map) -> Result<()> { + let Some(functions) = map.remove("functions") else { + return Ok(()); + }; + + if map.contains_key("tools") { + return Ok(()); + } + + match functions { + Value::Null => Ok(()), + Value::Array(functions) => { + let tools = functions + .into_iter() + .map(|function| { + let mut tool = serde_json::Map::new(); + tool.insert("type".into(), Value::String("function".into())); + tool.insert("function".into(), function); + Value::Object(tool) + }) + .collect(); + map.insert("tools".into(), Value::Array(tools)); + Ok(()) + } + _ => Err(GatewayError::Validation( + "moonshot providers expect legacy functions to be an array".into(), + )), + } +} + +fn convert_legacy_function_call(map: &mut serde_json::Map) -> Result<()> { + let Some(function_call) = map.remove("function_call") else { + return Ok(()); + }; + + if map.contains_key("tool_choice") { + return Ok(()); + } + + match function_call { + Value::Null => Ok(()), + Value::String(mode) if matches!(mode.as_str(), "none" | "auto") => { + map.insert("tool_choice".into(), Value::String(mode)); + Ok(()) + } + Value::String(mode) => Err(GatewayError::Validation(format!( + "moonshot providers only document tool_choice values \"none\" and \"auto\"; unsupported legacy function_call value \"{mode}\"" + ))), + Value::Object(_) => Err(GatewayError::Validation( + "moonshot providers do not document forced function_call objects; use tools with tool_choice set to \"auto\" or \"none\"".into(), + )), + _ => Err(GatewayError::Validation( + "moonshot providers expect legacy function_call to be a string or null".into(), + )), + } +} + +fn validate_tool_choice(map: &serde_json::Map) -> Result<()> { + let Some(tool_choice) = map.get("tool_choice") else { + return Ok(()); + }; + + match tool_choice { + Value::String(mode) if matches!(mode.as_str(), "none" | "auto") => Ok(()), + Value::String(mode) if mode == "required" => Err(GatewayError::Validation( + "moonshot providers do not support tool_choice=\"required\"".into(), + )), + Value::String(mode) => Err(GatewayError::Validation(format!( + "moonshot providers only document tool_choice values \"none\" and \"auto\", got \"{mode}\"" + ))), + Value::Object(_) => Err(GatewayError::Validation( + "moonshot providers do not document object-form tool_choice".into(), + )), + Value::Null => Ok(()), + _ => Err(GatewayError::Validation( + "moonshot providers expect tool_choice to be a string, object, or null".into(), + )), + } +} + +fn apply_model_specific_quirks(map: &mut serde_json::Map, model: &str) { + match model { + // kimi-k2.6 exposes thinking as an extra-body extension and documents + // sampling controls as fixed rather than user-tunable. + "kimi-k2.6" => { + map.remove("temperature"); + map.remove("top_p"); + map.remove("n"); + map.remove("presence_penalty"); + map.remove("frequency_penalty"); + } + // kimi-k2.5 documents a fixed temperature that depends on thinking mode, + // so omit user-supplied temperature and let the model choose the correct value. + "kimi-k2.5" => { + map.remove("temperature"); + } + _ => {} + } +} + +fn validate_generic_constraints(map: &serde_json::Map) -> Result<()> { + if let Some(temperature) = map.get("temperature").and_then(Value::as_f64) + && !(0.0..=1.0).contains(&temperature) + { + return Err(GatewayError::Validation(format!( + "moonshot providers require temperature to be within [0, 1], got {temperature}" + ))); + } + + if let Some(top_p) = map.get("top_p").and_then(Value::as_f64) + && !(0.0..=1.0).contains(&top_p) + { + return Err(GatewayError::Validation(format!( + "moonshot providers require top_p to be within [0, 1], got {top_p}" + ))); + } + + if let Some(n) = map.get("n").and_then(Value::as_u64) { + if !(1..=5).contains(&n) { + return Err(GatewayError::Validation(format!( + "moonshot providers require n to be within [1, 5], got {n}" + ))); + } + + if map.get("temperature").and_then(Value::as_f64) == Some(0.0) && n > 1 { + return Err(GatewayError::Validation( + "moonshot providers reject n > 1 when temperature is 0".into(), + )); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::{MoonshotAi, MoonshotAiCn}; + use crate::gateway::{ + provider_instance::ProviderAuth, + traits::{ChatTransform, ProviderMeta}, + types::openai::ChatCompletionRequest, + }; + + #[test] + fn provider_metadata_and_urls_are_correct() { + let global = MoonshotAi; + let cn = MoonshotAiCn; + let global_headers = global + .build_auth_headers(&ProviderAuth::ApiKey("moonshot-key".into())) + .unwrap(); + let cn_headers = cn + .build_auth_headers(&ProviderAuth::ApiKey("moonshot-cn-key".into())) + .unwrap(); + + assert_eq!(global.name(), "moonshotai"); + assert_eq!(cn.name(), "moonshotai-cn"); + assert_eq!(global.default_base_url(), "https://api.moonshot.cn/v1"); + assert_eq!(cn.default_base_url(), "https://api.moonshot.cn/v1"); + assert_eq!(global_headers["authorization"], "Bearer moonshot-key"); + assert_eq!(cn_headers["authorization"], "Bearer moonshot-cn-key"); + assert_eq!( + global.build_url(global.default_base_url(), "ignored"), + "https://api.moonshot.cn/v1/chat/completions" + ); + assert_eq!( + cn.build_url(cn.default_base_url(), "ignored"), + "https://api.moonshot.cn/v1/chat/completions" + ); + } + + #[test] + fn transform_request_converts_legacy_function_fields() { + let provider = MoonshotAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "moonshot-v1-128k", + "messages": [{"role": "user", "content": "hello"}], + "functions": [ + { + "name": "search", + "description": "Search the web", + "parameters": {"type": "object", "properties": {}} + } + ], + "function_call": "auto" + })) + .unwrap(); + + let transformed = provider.transform_request(&request).unwrap(); + + assert_eq!(transformed.get("functions"), None); + assert_eq!(transformed.get("function_call"), None); + assert_eq!(transformed["tool_choice"], "auto"); + assert_eq!(transformed["tools"][0]["type"], "function"); + assert_eq!(transformed["tools"][0]["function"]["name"], "search"); + } + + #[test] + fn transform_request_rejects_required_tool_choice() { + let provider = MoonshotAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "moonshot-v1-128k", + "messages": [{"role": "user", "content": "hello"}], + "tool_choice": "required" + })) + .unwrap(); + + let error = provider.transform_request(&request).unwrap_err(); + + assert_matches!( + error, + crate::gateway::error::GatewayError::Validation(message) + if message.contains("tool_choice") && message.contains("required") + ); + } + + #[test] + fn transform_request_strips_fixed_k26_sampling_fields() { + let provider = MoonshotAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "kimi-k2.6", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.8, + "top_p": 0.7, + "n": 3, + "presence_penalty": 0.5, + "frequency_penalty": 0.5, + "thinking": {"type": "disabled"} + })) + .unwrap(); + + let transformed = provider.transform_request(&request).unwrap(); + + assert_eq!(transformed.get("temperature"), None); + assert_eq!(transformed.get("top_p"), None); + assert_eq!(transformed.get("n"), None); + assert_eq!(transformed.get("presence_penalty"), None); + assert_eq!(transformed.get("frequency_penalty"), None); + assert_eq!(transformed["thinking"]["type"], "disabled"); + } + + #[test] + fn transform_request_rejects_temperature_above_one() { + let provider = MoonshotAiCn; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "moonshot-v1-128k", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 1.5 + })) + .unwrap(); + + let error = provider.transform_request(&request).unwrap_err(); + + assert_matches!( + error, + crate::gateway::error::GatewayError::Validation(message) + if message.contains("temperature") && message.contains("[0, 1]") + ); + } + + #[test] + fn transform_request_rejects_multiple_choices_when_temperature_is_zero() { + let provider = MoonshotAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "moonshot-v1-128k", + "messages": [{"role": "user", "content": "hello"}], + "temperature": 0.0, + "n": 2 + })) + .unwrap(); + + let error = provider.transform_request(&request).unwrap_err(); + + assert_matches!( + error, + crate::gateway::error::GatewayError::Validation(message) + if message.contains("n > 1") && message.contains("temperature is 0") + ); + } +} diff --git a/src/proxy/provider.rs b/src/proxy/provider.rs index b07b975..4623b6b 100644 --- a/src/proxy/provider.rs +++ b/src/proxy/provider.rs @@ -69,6 +69,14 @@ fn provider_auth_and_base_url(config: &ProviderConfig) -> Result<(ProviderAuth, ProviderAuth::ApiKey(config.api_key.clone()), parse_base_url(config.api_base.as_deref())?, ), + ProviderConfig::MoonshotAi(config) => ( + ProviderAuth::ApiKey(config.api_key.clone()), + parse_base_url(config.api_base.as_deref())?, + ), + ProviderConfig::MoonshotAiCn(config) => ( + ProviderAuth::ApiKey(config.api_key.clone()), + parse_base_url(config.api_base.as_deref())?, + ), ProviderConfig::OpenAI(config) => ( ProviderAuth::ApiKey(config.api_key.clone()), parse_base_url(config.api_base.as_deref())?, @@ -167,7 +175,8 @@ mod tests { gateway::providers::configs::{ AzureProviderConfig, BedrockProviderConfig, CohereProviderConfig, FireworksAiProviderConfig, GroqProviderConfig, MistralProviderConfig, - OpenRouterProviderConfig, XaiProviderConfig, + MoonshotAiCnProviderConfig, MoonshotAiProviderConfig, OpenRouterProviderConfig, + XaiProviderConfig, }, }; @@ -290,6 +299,41 @@ mod tests { ); } + #[test] + fn provider_auth_and_base_url_returns_moonshot_api_key_and_optional_base_url() { + let config = ProviderConfig::MoonshotAi(MoonshotAiProviderConfig { + api_key: "moonshot-key".into(), + api_base: Some("https://api.moonshot.cn/v1".into()), + }); + + let (auth, base_url_override) = provider_auth_and_base_url(&config).unwrap(); + + assert_eq!(auth.api_key_for("moonshotai").unwrap(), "moonshot-key"); + assert_eq!( + base_url_override.as_ref().map(Url::as_str), + Some("https://api.moonshot.cn/v1") + ); + } + + #[test] + fn provider_auth_and_base_url_returns_moonshot_cn_api_key_and_optional_base_url() { + let config = ProviderConfig::MoonshotAiCn(MoonshotAiCnProviderConfig { + api_key: "moonshot-cn-key".into(), + api_base: Some("https://api.moonshot.cn/v1".into()), + }); + + let (auth, base_url_override) = provider_auth_and_base_url(&config).unwrap(); + + assert_eq!( + auth.api_key_for("moonshotai-cn").unwrap(), + "moonshot-cn-key" + ); + assert_eq!( + base_url_override.as_ref().map(Url::as_str), + Some("https://api.moonshot.cn/v1") + ); + } + #[test] fn provider_auth_and_base_url_returns_xai_api_key_and_optional_base_url() { let config = ProviderConfig::Xai(XaiProviderConfig { diff --git a/ui/src/i18n/locales/en.json b/ui/src/i18n/locales/en.json index 1763cd5..f338721 100644 --- a/ui/src/i18n/locales/en.json +++ b/ui/src/i18n/locales/en.json @@ -154,6 +154,8 @@ "groq": "Groq", "xai": "xAI", "mistral": "Mistral", + "moonshotai": "Moonshot AI", + "moonshotai-cn": "Moonshot AI (CN)", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", @@ -227,6 +229,8 @@ "fireworks-ai": "Fireworks AI", "groq": "Groq", "xai": "xAI", + "moonshotai": "Moonshot AI", + "moonshotai-cn": "Moonshot AI (CN)", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", diff --git a/ui/src/i18n/locales/zh-CN.json b/ui/src/i18n/locales/zh-CN.json index e8c97d5..1726d87 100644 --- a/ui/src/i18n/locales/zh-CN.json +++ b/ui/src/i18n/locales/zh-CN.json @@ -154,6 +154,8 @@ "groq": "Groq", "xai": "xAI", "mistral": "Mistral", + "moonshotai": "Moonshot AI", + "moonshotai-cn": "Moonshot AI (CN)", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", @@ -227,6 +229,8 @@ "fireworks-ai": "Fireworks AI", "groq": "Groq", "xai": "xAI", + "moonshotai": "Moonshot AI", + "moonshotai-cn": "Moonshot AI (CN)", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", diff --git a/ui/src/lib/api/types.ts b/ui/src/lib/api/types.ts index 323d067..8006264 100644 --- a/ui/src/lib/api/types.ts +++ b/ui/src/lib/api/types.ts @@ -45,6 +45,8 @@ export const PROVIDER_TYPE_VARIANTS = [ 'groq', 'xai', 'mistral', + 'moonshotai', + 'moonshotai-cn', 'azure', 'anthropic', 'gemini', @@ -129,6 +131,16 @@ export type Provider = type: 'mistral'; config: ApiBaseProviderConfig; } + | { + name: string; + type: 'moonshotai'; + config: ApiBaseProviderConfig; + } + | { + name: string; + type: 'moonshotai-cn'; + config: ApiBaseProviderConfig; + } | { name: string; type: 'bedrock';