From 039d80b45b883a3bdb564bd333b2afe2b79a2c8c Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Sun, 3 May 2026 15:19:27 +0800 Subject: [PATCH] feat(provider): add xai --- src/config/entities/providers-schema.json | 2 + src/config/entities/providers.rs | 8 +++ src/gateway/providers/mod.rs | 14 ++++- src/gateway/providers/xai.rs | 69 +++++++++++++++++++++++ src/proxy/provider.rs | 22 +++++++- ui/src/i18n/locales/en.json | 2 + ui/src/i18n/locales/zh-CN.json | 2 + ui/src/lib/api/types.ts | 6 ++ 8 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 src/gateway/providers/xai.rs diff --git a/src/config/entities/providers-schema.json b/src/config/entities/providers-schema.json index 2447cab..6b1d3ed 100644 --- a/src/config/entities/providers-schema.json +++ b/src/config/entities/providers-schema.json @@ -12,6 +12,7 @@ "deepseek", "gemini", "groq", + "xai", "mistral", "openai", "openrouter" @@ -53,6 +54,7 @@ "deepseek", "gemini", "groq", + "xai", "mistral", "openai", "openrouter" diff --git a/src/config/entities/providers.rs b/src/config/entities/providers.rs index 067af22..91dfb95 100644 --- a/src/config/entities/providers.rs +++ b/src/config/entities/providers.rs @@ -34,6 +34,8 @@ pub enum ProviderConfig { Gemini(configs::GeminiProviderConfig), #[serde(rename = "groq")] Groq(configs::GroqProviderConfig), + #[serde(rename = "xai")] + Xai(configs::XaiProviderConfig), #[serde(rename = "mistral")] Mistral(configs::MistralProviderConfig), #[serde(rename = "openai")] @@ -51,6 +53,7 @@ impl ProviderConfig { Self::DeepSeek(_) => identifiers::DEEPSEEK, Self::Gemini(_) => identifiers::GEMINI, Self::Groq(_) => identifiers::GROQ, + Self::Xai(_) => identifiers::XAI, Self::Mistral(_) => identifiers::MISTRAL, Self::OpenAI(_) => identifiers::OPENAI, Self::OpenRouter(_) => identifiers::OPENROUTER, @@ -156,6 +159,11 @@ mod tests { "type": "groq", "config": { "api_key": "test_key" } }), true, None)] + #[case::xai_ok(json!({ + "name": "xai-primary", + "type": "xai", + "config": { "api_key": "test_key" } + }), true, None)] #[case::mistral_ok(json!({ "name": "mistral-primary", "type": "mistral", diff --git a/src/gateway/providers/mod.rs b/src/gateway/providers/mod.rs index b093b8e..51c9934 100644 --- a/src/gateway/providers/mod.rs +++ b/src/gateway/providers/mod.rs @@ -8,6 +8,7 @@ pub mod macros; pub mod mistral; pub mod openai; pub mod openrouter; +pub mod xai; pub use anthropic::AnthropicDef; pub use azure::AzureDef; @@ -18,9 +19,12 @@ pub use groq::Groq; pub use mistral::Mistral; pub use openai::OpenAIDef; pub use openrouter::OpenRouter; +pub use xai::Xai; pub mod identifiers { - use super::{anthropic, azure, bedrock, deepseek, gemini, groq, mistral, openai, openrouter}; + use super::{ + anthropic, azure, bedrock, deepseek, gemini, groq, mistral, openai, openrouter, xai, + }; pub const ANTHROPIC: &str = anthropic::IDENTIFIER; pub const AZURE: &str = azure::IDENTIFIER; @@ -31,6 +35,7 @@ pub mod identifiers { pub const MISTRAL: &str = mistral::IDENTIFIER; pub const OPENAI: &str = openai::IDENTIFIER; pub const OPENROUTER: &str = openrouter::IDENTIFIER; + pub const XAI: &str = xai::IDENTIFIER; } pub mod configs { @@ -38,7 +43,7 @@ pub mod configs { anthropic::AnthropicProviderConfig, azure::AzureProviderConfig, bedrock::BedrockProviderConfig, deepseek::DeepSeekProviderConfig, gemini::GeminiProviderConfig, groq::GroqProviderConfig, mistral::MistralProviderConfig, - openai::OpenAIProviderConfig, openrouter::OpenRouterProviderConfig, + openai::OpenAIProviderConfig, openrouter::OpenRouterProviderConfig, xai::XaiProviderConfig, }; } @@ -54,13 +59,15 @@ pub fn default_provider_registry() -> Result { .register(Groq)? .register(Mistral)? .register(OpenAIDef)? - .register(OpenRouter)?; + .register(OpenRouter)? + .register(Xai)?; Ok(builder.build()) } #[cfg(test)] mod tests { use pretty_assertions::assert_eq; + use super::default_provider_registry; #[test] @@ -76,6 +83,7 @@ mod tests { assert_eq!(registry.get("mistral").unwrap().name(), "mistral"); assert_eq!(registry.get("deepseek").unwrap().name(), "deepseek"); assert_eq!(registry.get("openrouter").unwrap().name(), "openrouter"); + assert_eq!(registry.get("xai").unwrap().name(), "xai"); assert!(registry.get("missing").is_none()); } } diff --git a/src/gateway/providers/xai.rs b/src/gateway/providers/xai.rs new file mode 100644 index 0000000..1547ad0 --- /dev/null +++ b/src/gateway/providers/xai.rs @@ -0,0 +1,69 @@ +use serde::{Deserialize, Serialize}; + +use crate::gateway::providers::macros::provider; + +/// Provider identifier string used to look up xAI in the gateway registry. +pub const IDENTIFIER: &str = "xai"; + +/// Configuration for an xAI provider deployment. +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct XaiProviderConfig { + pub api_key: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub api_base: Option, +} + +provider!(Xai { + display_name: "xai", + base_url: "https://api.x.ai/v1", + auth: bearer, + quirks: { + unsupported_params: &["logit_bias"], + inject_stream_usage: true, + } +}); + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::Xai; + use crate::gateway::{ + traits::{ChatTransform, ProviderMeta}, + types::openai::ChatCompletionRequest, + }; + + #[test] + fn provider_macro_expands_correctly() { + let provider = Xai; + + assert_eq!(provider.name(), "xai"); + assert_eq!(provider.default_base_url(), "https://api.x.ai/v1"); + + assert_eq!( + provider.build_url(provider.default_base_url(), "ignored"), + "https://api.x.ai/v1/chat/completions" + ); + } + + #[test] + fn transform_request_applies_xai_quirks() { + let provider = Xai; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "grok-4.3", + "messages": [{"role": "user", "content": "hello"}], + "stream": true, + "max_tokens": 128, + "logit_bias": {"42": 100} + })) + .unwrap(); + + let transformed = provider.transform_request(&request).unwrap(); + + assert_eq!(transformed.get("logit_bias"), None); + assert_eq!(transformed["max_tokens"], 128); + assert_eq!(transformed["stream_options"]["include_usage"], true); + } +} diff --git a/src/proxy/provider.rs b/src/proxy/provider.rs index 56692da..6f824e2 100644 --- a/src/proxy/provider.rs +++ b/src/proxy/provider.rs @@ -53,6 +53,10 @@ fn provider_auth_and_base_url(config: &ProviderConfig) -> Result<(ProviderAuth, ProviderAuth::ApiKey(config.api_key.clone()), parse_base_url(config.api_base.as_deref())?, ), + ProviderConfig::Xai(config) => ( + ProviderAuth::ApiKey(config.api_key.clone()), + parse_base_url(config.api_base.as_deref())?, + ), ProviderConfig::Mistral(config) => ( ProviderAuth::ApiKey(config.api_key.clone()), parse_base_url(config.api_base.as_deref())?, @@ -154,7 +158,7 @@ mod tests { config::entities::providers::ProviderConfig, gateway::providers::configs::{ AzureProviderConfig, BedrockProviderConfig, GroqProviderConfig, MistralProviderConfig, - OpenRouterProviderConfig, + OpenRouterProviderConfig, XaiProviderConfig, }, }; @@ -245,6 +249,22 @@ mod tests { ); } + #[test] + fn provider_auth_and_base_url_returns_xai_api_key_and_optional_base_url() { + let config = ProviderConfig::Xai(XaiProviderConfig { + api_key: "xai-key".into(), + api_base: Some("https://api.x.ai/v1".into()), + }); + + let (auth, base_url_override) = provider_auth_and_base_url(&config).unwrap(); + + assert_eq!(auth.api_key_for("xai").unwrap(), "xai-key"); + assert_eq!( + base_url_override.as_ref().map(Url::as_str), + Some("https://api.x.ai/v1") + ); + } + #[test] fn provider_auth_and_base_url_returns_bedrock_static_credentials() { let config = ProviderConfig::Bedrock(BedrockProviderConfig { diff --git a/ui/src/i18n/locales/en.json b/ui/src/i18n/locales/en.json index b719521..dae1ce7 100644 --- a/ui/src/i18n/locales/en.json +++ b/ui/src/i18n/locales/en.json @@ -150,6 +150,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "groq": "Groq", + "xai": "xAI", "mistral": "Mistral", "azure": "Azure OpenAI", "anthropic": "Anthropic", @@ -221,6 +222,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "groq": "Groq", + "xai": "xAI", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", diff --git a/ui/src/i18n/locales/zh-CN.json b/ui/src/i18n/locales/zh-CN.json index e42203c..0f09dc7 100644 --- a/ui/src/i18n/locales/zh-CN.json +++ b/ui/src/i18n/locales/zh-CN.json @@ -150,6 +150,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "groq": "Groq", + "xai": "xAI", "mistral": "Mistral", "azure": "Azure OpenAI", "anthropic": "Anthropic", @@ -221,6 +222,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "groq": "Groq", + "xai": "xAI", "azure": "Azure OpenAI", "anthropic": "Anthropic", "gemini": "Gemini", diff --git a/ui/src/lib/api/types.ts b/ui/src/lib/api/types.ts index c87be64..110ee90 100644 --- a/ui/src/lib/api/types.ts +++ b/ui/src/lib/api/types.ts @@ -41,6 +41,7 @@ export const PROVIDER_TYPE_VARIANTS = [ 'openai', 'openrouter', 'groq', + 'xai', 'mistral', 'azure', 'anthropic', @@ -106,6 +107,11 @@ export type Provider = type: 'groq'; config: ApiBaseProviderConfig; } + | { + name: string; + type: 'xai'; + config: ApiBaseProviderConfig; + } | { name: string; type: 'mistral';