diff --git a/src/config/entities/providers-schema.json b/src/config/entities/providers-schema.json index 98d85bf..04ed1a0 100644 --- a/src/config/entities/providers-schema.json +++ b/src/config/entities/providers-schema.json @@ -11,6 +11,7 @@ "bedrock", "cohere", "deepseek", + "fireworks-ai", "gemini", "groq", "xai", @@ -54,6 +55,7 @@ "anthropic", "cohere", "deepseek", + "fireworks-ai", "gemini", "groq", "xai", diff --git a/src/config/entities/providers.rs b/src/config/entities/providers.rs index cbc03dc..86a82db 100644 --- a/src/config/entities/providers.rs +++ b/src/config/entities/providers.rs @@ -32,6 +32,8 @@ pub enum ProviderConfig { Cohere(configs::CohereProviderConfig), #[serde(rename = "deepseek")] DeepSeek(configs::DeepSeekProviderConfig), + #[serde(rename = "fireworks-ai")] + FireworksAi(configs::FireworksAiProviderConfig), #[serde(rename = "gemini")] Gemini(configs::GeminiProviderConfig), #[serde(rename = "groq")] @@ -54,6 +56,7 @@ impl ProviderConfig { Self::Bedrock(_) => identifiers::BEDROCK, Self::Cohere(_) => identifiers::COHERE, Self::DeepSeek(_) => identifiers::DEEPSEEK, + Self::FireworksAi(_) => identifiers::FIREWORKS_AI, Self::Gemini(_) => identifiers::GEMINI, Self::Groq(_) => identifiers::GROQ, Self::Xai(_) => identifiers::XAI, @@ -157,6 +160,11 @@ mod tests { "type": "cohere", "config": { "api_key": "test_key" } }), true, None)] + #[case::fireworks_ai_ok(json!({ + "name": "fireworks-primary", + "type": "fireworks-ai", + "config": { "api_key": "test_key" } + }), true, None)] #[case::openrouter_ok(json!({ "name": "openrouter-primary", "type": "openrouter", diff --git a/src/gateway/providers/fireworks.rs b/src/gateway/providers/fireworks.rs new file mode 100644 index 0000000..f3f30cb --- /dev/null +++ b/src/gateway/providers/fireworks.rs @@ -0,0 +1,171 @@ +use http::{HeaderMap, HeaderValue, header::AUTHORIZATION}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::gateway::{ + error::{GatewayError, Result}, + provider_instance::ProviderAuth, + traits::{ChatTransform, EmbedTransform, ProviderCapabilities, ProviderMeta}, + types::{ + embed::{EmbedRequestBody, EmbeddingRequest}, + openai::ChatCompletionRequest, + }, +}; + +/// Fireworks AI currently uses its OpenAI-compatible inference API. +/// Docs: https://docs.fireworks.ai/tools-sdks/openai-compatibility +pub const IDENTIFIER: &str = "fireworks-ai"; + +#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)] +pub struct FireworksAiProviderConfig { + pub api_key: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub api_base: Option, +} + +pub struct FireworksAi; + +impl ProviderMeta for FireworksAi { + fn name(&self) -> &'static str { + IDENTIFIER + } + + fn default_base_url(&self) -> &'static str { + "https://api.fireworks.ai/inference/v1" + } + + fn build_auth_headers(&self, auth: &ProviderAuth) -> Result { + let mut headers = HeaderMap::new(); + let value = HeaderValue::from_str(&format!("Bearer {}", auth.api_key_for(self.name())?)) + .map_err(|error| GatewayError::Validation(error.to_string()))?; + headers.insert(AUTHORIZATION, value); + Ok(headers) + } +} + +impl ChatTransform for FireworksAi { + fn transform_request(&self, request: &ChatCompletionRequest) -> Result { + let mut body = serde_json::to_value(request) + .map_err(|error| GatewayError::Transform(error.to_string()))?; + + if let Value::Object(map) = &mut body { + // Fireworks defaults to truncating max_tokens on context overflow. + // Set the documented override so requests keep OpenAI-style error semantics. + map.entry("context_length_exceeded_behavior") + .or_insert_with(|| Value::String("error".into())); + } + + Ok(body) + } +} + +impl EmbedTransform for FireworksAi { + fn transform_embeddings_request(&self, request: &EmbeddingRequest) -> Result { + let mut body = serde_json::to_value(request) + .map_err(|error| GatewayError::Transform(error.to_string()))?; + + if let Value::Object(map) = &mut body { + // Fireworks documents dimensions, return_logits, normalize, and prompt_template, + // but not OpenAI's encoding_format or user fields on /embeddings. + map.remove("encoding_format"); + map.remove("user"); + } + + Ok(EmbedRequestBody::Json(body)) + } +} + +impl ProviderCapabilities for FireworksAi { + fn as_embed_transform(&self) -> Option<&dyn EmbedTransform> { + Some(self) + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde_json::json; + + use super::FireworksAi; + use crate::gateway::{ + provider_instance::ProviderAuth, + traits::{ChatTransform, EmbedTransform, ProviderCapabilities, ProviderMeta}, + types::{ + embed::{EmbedRequestBody, EmbeddingRequest}, + openai::ChatCompletionRequest, + }, + }; + + #[test] + fn provider_metadata_and_urls_are_correct() { + let provider = FireworksAi; + let headers = provider + .build_auth_headers(&ProviderAuth::ApiKey("fw-key".into())) + .unwrap(); + + assert_eq!(provider.name(), "fireworks-ai"); + assert_eq!( + provider.default_base_url(), + "https://api.fireworks.ai/inference/v1" + ); + assert_eq!(headers["authorization"], "Bearer fw-key"); + assert_eq!( + provider.build_url(provider.default_base_url(), "ignored"), + "https://api.fireworks.ai/inference/v1/chat/completions" + ); + assert!(provider.as_embed_transform().is_some()); + } + + #[test] + fn transform_request_defaults_to_openai_context_length_behavior() { + let provider = FireworksAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "accounts/fireworks/models/kimi-k2-instruct-0905", + "messages": [{"role": "user", "content": "hello"}] + })) + .unwrap(); + + let transformed = provider.transform_request(&request).unwrap(); + + assert_eq!(transformed["context_length_exceeded_behavior"], "error"); + } + + #[test] + fn transform_request_preserves_explicit_context_length_behavior() { + let provider = FireworksAi; + let request: ChatCompletionRequest = serde_json::from_value(json!({ + "model": "accounts/fireworks/models/kimi-k2-instruct-0905", + "messages": [{"role": "user", "content": "hello"}], + "context_length_exceeded_behavior": "truncate" + })) + .unwrap(); + + let transformed = provider.transform_request(&request).unwrap(); + + assert_eq!(transformed["context_length_exceeded_behavior"], "truncate"); + } + + #[test] + fn transform_embeddings_request_strips_unsupported_fields() { + let provider = FireworksAi; + let request: EmbeddingRequest = serde_json::from_value(json!({ + "model": "fireworks/qwen3-embedding-8b", + "input": ["hello"], + "dimensions": 128, + "encoding_format": "float", + "user": "user-123" + })) + .unwrap(); + + let body = provider.transform_embeddings_request(&request).unwrap(); + + match body { + EmbedRequestBody::Json(value) => { + assert_eq!(value["dimensions"], 128); + assert_eq!(value.get("encoding_format"), None); + assert_eq!(value.get("user"), None); + } + } + } +} diff --git a/src/gateway/providers/mod.rs b/src/gateway/providers/mod.rs index e366ff3..c0bdde6 100644 --- a/src/gateway/providers/mod.rs +++ b/src/gateway/providers/mod.rs @@ -3,6 +3,7 @@ pub mod azure; pub mod bedrock; pub mod cohere; pub mod deepseek; +pub mod fireworks; pub mod gemini; pub mod groq; pub mod macros; @@ -16,6 +17,7 @@ pub use azure::AzureDef; pub use bedrock::BedrockDef; pub use cohere::Cohere; pub use deepseek::DeepSeek; +pub use fireworks::FireworksAi; pub use gemini::GoogleDef; pub use groq::Groq; pub use mistral::Mistral; @@ -25,7 +27,8 @@ pub use xai::Xai; pub mod identifiers { use super::{ - anthropic, azure, bedrock, cohere, deepseek, gemini, groq, mistral, openai, openrouter, xai, + anthropic, azure, bedrock, cohere, deepseek, fireworks, gemini, groq, mistral, openai, + openrouter, xai, }; pub const ANTHROPIC: &str = anthropic::IDENTIFIER; @@ -33,6 +36,7 @@ pub mod identifiers { pub const BEDROCK: &str = bedrock::IDENTIFIER; pub const COHERE: &str = cohere::IDENTIFIER; pub const DEEPSEEK: &str = deepseek::IDENTIFIER; + pub const FIREWORKS_AI: &str = fireworks::IDENTIFIER; pub const GEMINI: &str = gemini::IDENTIFIER; pub const GROQ: &str = groq::IDENTIFIER; pub const MISTRAL: &str = mistral::IDENTIFIER; @@ -45,9 +49,9 @@ pub mod configs { pub use super::{ anthropic::AnthropicProviderConfig, azure::AzureProviderConfig, bedrock::BedrockProviderConfig, cohere::CohereProviderConfig, - deepseek::DeepSeekProviderConfig, gemini::GeminiProviderConfig, groq::GroqProviderConfig, - mistral::MistralProviderConfig, openai::OpenAIProviderConfig, - openrouter::OpenRouterProviderConfig, xai::XaiProviderConfig, + deepseek::DeepSeekProviderConfig, fireworks::FireworksAiProviderConfig, + gemini::GeminiProviderConfig, groq::GroqProviderConfig, mistral::MistralProviderConfig, + openai::OpenAIProviderConfig, openrouter::OpenRouterProviderConfig, xai::XaiProviderConfig, }; } @@ -60,6 +64,7 @@ pub fn default_provider_registry() -> Result { .register(BedrockDef)? .register(Cohere)? .register(DeepSeek)? + .register(FireworksAi)? .register(GoogleDef)? .register(Groq)? .register(Mistral)? @@ -84,6 +89,7 @@ mod tests { assert_eq!(registry.get("anthropic").unwrap().name(), "anthropic"); assert_eq!(registry.get("bedrock").unwrap().name(), "bedrock"); assert_eq!(registry.get("cohere").unwrap().name(), "cohere"); + assert_eq!(registry.get("fireworks-ai").unwrap().name(), "fireworks-ai"); assert_eq!(registry.get("gemini").unwrap().name(), "gemini"); assert_eq!(registry.get("groq").unwrap().name(), "groq"); assert_eq!(registry.get("mistral").unwrap().name(), "mistral"); diff --git a/src/proxy/provider.rs b/src/proxy/provider.rs index 8fc5e79..b07b975 100644 --- a/src/proxy/provider.rs +++ b/src/proxy/provider.rs @@ -49,6 +49,10 @@ fn provider_auth_and_base_url(config: &ProviderConfig) -> Result<(ProviderAuth, ProviderAuth::ApiKey(config.api_key.clone()), parse_base_url(config.api_base.as_deref())?, ), + ProviderConfig::FireworksAi(config) => ( + ProviderAuth::ApiKey(config.api_key.clone()), + parse_base_url(config.api_base.as_deref())?, + ), ProviderConfig::Gemini(config) => ( ProviderAuth::ApiKey(config.api_key.clone()), parse_base_url(config.api_base.as_deref())?, @@ -161,8 +165,9 @@ mod tests { use crate::{ config::entities::providers::ProviderConfig, gateway::providers::configs::{ - AzureProviderConfig, BedrockProviderConfig, CohereProviderConfig, GroqProviderConfig, - MistralProviderConfig, OpenRouterProviderConfig, XaiProviderConfig, + AzureProviderConfig, BedrockProviderConfig, CohereProviderConfig, + FireworksAiProviderConfig, GroqProviderConfig, MistralProviderConfig, + OpenRouterProviderConfig, XaiProviderConfig, }, }; @@ -237,6 +242,22 @@ mod tests { ); } + #[test] + fn provider_auth_and_base_url_returns_fireworks_api_key_and_optional_base_url() { + let config = ProviderConfig::FireworksAi(FireworksAiProviderConfig { + api_key: "fireworks-key".into(), + api_base: Some("https://api.fireworks.ai/inference/v1".into()), + }); + + let (auth, base_url_override) = provider_auth_and_base_url(&config).unwrap(); + + assert_eq!(auth.api_key_for("fireworks-ai").unwrap(), "fireworks-key"); + assert_eq!( + base_url_override.as_ref().map(Url::as_str), + Some("https://api.fireworks.ai/inference/v1") + ); + } + #[test] fn provider_auth_and_base_url_returns_groq_api_key_and_optional_base_url() { let config = ProviderConfig::Groq(GroqProviderConfig { diff --git a/ui/src/i18n/locales/en.json b/ui/src/i18n/locales/en.json index bc90abf..1763cd5 100644 --- a/ui/src/i18n/locales/en.json +++ b/ui/src/i18n/locales/en.json @@ -150,6 +150,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "cohere": "Cohere", + "fireworks-ai": "Fireworks AI", "groq": "Groq", "xai": "xAI", "mistral": "Mistral", @@ -223,6 +224,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "cohere": "Cohere", + "fireworks-ai": "Fireworks AI", "groq": "Groq", "xai": "xAI", "azure": "Azure OpenAI", diff --git a/ui/src/i18n/locales/zh-CN.json b/ui/src/i18n/locales/zh-CN.json index 3dbdccc..e8c97d5 100644 --- a/ui/src/i18n/locales/zh-CN.json +++ b/ui/src/i18n/locales/zh-CN.json @@ -150,6 +150,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "cohere": "Cohere", + "fireworks-ai": "Fireworks AI", "groq": "Groq", "xai": "xAI", "mistral": "Mistral", @@ -223,6 +224,7 @@ "openai": "OpenAI", "openrouter": "OpenRouter", "cohere": "Cohere", + "fireworks-ai": "Fireworks AI", "groq": "Groq", "xai": "xAI", "azure": "Azure OpenAI", diff --git a/ui/src/lib/api/types.ts b/ui/src/lib/api/types.ts index c08d2bb..323d067 100644 --- a/ui/src/lib/api/types.ts +++ b/ui/src/lib/api/types.ts @@ -41,6 +41,7 @@ export const PROVIDER_TYPE_VARIANTS = [ 'openai', 'openrouter', 'cohere', + 'fireworks-ai', 'groq', 'xai', 'mistral', @@ -108,6 +109,11 @@ export type Provider = type: 'cohere'; config: ApiBaseProviderConfig; } + | { + name: string; + type: 'fireworks-ai'; + config: ApiBaseProviderConfig; + } | { name: string; type: 'groq';