From be10929ad9549f0138e8aa310df862447ef13ac0 Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Sun, 3 May 2026 13:47:00 +0800 Subject: [PATCH 1/2] chore(test): move to pretty-assertions --- src/config/entities/apikeys.rs | 1 + src/config/entities/mod.rs | 1 + src/config/entities/models.rs | 1 + src/config/entities/providers.rs | 1 + src/gateway/error.rs | 1 + src/gateway/formats/anthropic_messages.rs | 1 + src/gateway/formats/openai/mod.rs | 1 + src/gateway/gateway.rs | 1 + src/gateway/provider_instance.rs | 1 + src/gateway/providers/anthropic/mod.rs | 1 + src/gateway/providers/anthropic/transform.rs | 1 + src/gateway/providers/bedrock/transform.rs | 1 + src/gateway/providers/groq.rs | 7 ++++--- src/gateway/providers/macros.rs | 1 + src/gateway/providers/mistral.rs | 10 ++++++---- src/gateway/providers/mod.rs | 1 + src/gateway/providers/openai.rs | 1 + src/gateway/providers/openrouter.rs | 9 +++++---- src/gateway/streams/bridged.rs | 1 + src/gateway/streams/hub.rs | 1 + src/gateway/streams/native.rs | 1 + src/gateway/streams/reader/aws_event_stream.rs | 1 + src/gateway/streams/reader/sse.rs | 1 + src/gateway/traits/chat_format.rs | 1 + src/gateway/traits/provider.rs | 9 +++++---- src/gateway/types/anthropic.rs | 1 + src/gateway/types/common.rs | 1 + src/gateway/types/embed.rs | 1 + src/gateway/types/openai/mod.rs | 1 + src/gateway/types/openai/responses.rs | 1 + .../handlers/chat_completions/span_attributes/tests.rs | 1 + src/proxy/handlers/messages/types.rs | 1 + src/proxy/hooks/rate_limit/concurrent/local.rs | 1 + src/proxy/hooks/rate_limit/concurrent/mod.rs | 1 + src/proxy/hooks/rate_limit/concurrent/utils.rs | 1 + src/proxy/hooks/rate_limit/ratelimit/local.rs | 1 + src/proxy/hooks/rate_limit/ratelimit/mod.rs | 1 + src/proxy/hooks/rate_limit/ratelimit/utils.rs | 1 + src/proxy/provider.rs | 1 + src/utils/instance.rs | 1 + 40 files changed, 56 insertions(+), 15 deletions(-) diff --git a/src/config/entities/apikeys.rs b/src/config/entities/apikeys.rs index a8d6e85..8a2ce4c 100644 --- a/src/config/entities/apikeys.rs +++ b/src/config/entities/apikeys.rs @@ -75,6 +75,7 @@ impl ApiKeysStore { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error}; diff --git a/src/config/entities/mod.rs b/src/config/entities/mod.rs index 7b26f3b..76754fe 100644 --- a/src/config/entities/mod.rs +++ b/src/config/entities/mod.rs @@ -448,6 +448,7 @@ impl EntityStore { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::{sync::Mutex, time::Duration}; use anyhow::Result; diff --git a/src/config/entities/models.rs b/src/config/entities/models.rs index 5e984ac..0f91d1b 100644 --- a/src/config/entities/models.rs +++ b/src/config/entities/models.rs @@ -93,6 +93,7 @@ impl ModelsStore { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error}; diff --git a/src/config/entities/providers.rs b/src/config/entities/providers.rs index 4f0a9e5..a2cd924 100644 --- a/src/config/entities/providers.rs +++ b/src/config/entities/providers.rs @@ -112,6 +112,7 @@ impl ProvidersStore { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error}; diff --git a/src/gateway/error.rs b/src/gateway/error.rs index d584263..12d4f7b 100644 --- a/src/gateway/error.rs +++ b/src/gateway/error.rs @@ -88,6 +88,7 @@ pub type Result = std::result::Result; #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::*; diff --git a/src/gateway/formats/anthropic_messages.rs b/src/gateway/formats/anthropic_messages.rs index 48195ef..c376eb3 100644 --- a/src/gateway/formats/anthropic_messages.rs +++ b/src/gateway/formats/anthropic_messages.rs @@ -1000,6 +1000,7 @@ fn openai_finish_reason_to_anthropic_stream(finish_reason: &str) -> String { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::AnthropicMessagesFormat; diff --git a/src/gateway/formats/openai/mod.rs b/src/gateway/formats/openai/mod.rs index bde0581..aaa63e0 100644 --- a/src/gateway/formats/openai/mod.rs +++ b/src/gateway/formats/openai/mod.rs @@ -68,6 +68,7 @@ impl ChatFormat for OpenAIChatFormat { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::OpenAIChatFormat; diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index 3d1e924..a465944 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -402,6 +402,7 @@ async fn provider_error(response: reqwest::Response, provider: &str) -> GatewayE #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::{ borrow::Cow, sync::{ diff --git a/src/gateway/provider_instance.rs b/src/gateway/provider_instance.rs index c9ac050..a730299 100644 --- a/src/gateway/provider_instance.rs +++ b/src/gateway/provider_instance.rs @@ -159,6 +159,7 @@ impl ProviderRegistryBuilder { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::{borrow::Cow, sync::Arc}; use http::{ diff --git a/src/gateway/providers/anthropic/mod.rs b/src/gateway/providers/anthropic/mod.rs index 9b56932..4a0402c 100644 --- a/src/gateway/providers/anthropic/mod.rs +++ b/src/gateway/providers/anthropic/mod.rs @@ -120,6 +120,7 @@ impl NativeAnthropicMessagesSupport for AnthropicDef { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::AnthropicDef; diff --git a/src/gateway/providers/anthropic/transform.rs b/src/gateway/providers/anthropic/transform.rs index 171995b..da9163f 100644 --- a/src/gateway/providers/anthropic/transform.rs +++ b/src/gateway/providers/anthropic/transform.rs @@ -743,6 +743,7 @@ fn now_unix_secs() -> u64 { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{ diff --git a/src/gateway/providers/bedrock/transform.rs b/src/gateway/providers/bedrock/transform.rs index 253b2f1..4f832ee 100644 --- a/src/gateway/providers/bedrock/transform.rs +++ b/src/gateway/providers/bedrock/transform.rs @@ -601,6 +601,7 @@ fn current_unix_timestamp() -> Result { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{ diff --git a/src/gateway/providers/groq.rs b/src/gateway/providers/groq.rs index a3b3ca2..4e35b39 100644 --- a/src/gateway/providers/groq.rs +++ b/src/gateway/providers/groq.rs @@ -78,6 +78,7 @@ impl ProviderCapabilities for Groq {} #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::Groq; @@ -90,10 +91,10 @@ mod tests { fn provider_metadata_and_url_are_correct() { let provider = Groq; - pretty_assertions::assert_eq!(provider.name(), "groq"); - pretty_assertions::assert_eq!(provider.default_base_url(), "https://api.groq.com/openai"); + assert_eq!(provider.name(), "groq"); + assert_eq!(provider.default_base_url(), "https://api.groq.com/openai"); - pretty_assertions::assert_eq!( + assert_eq!( provider.build_url(provider.default_base_url(), "ignored"), "https://api.groq.com/openai/v1/chat/completions" ); diff --git a/src/gateway/providers/macros.rs b/src/gateway/providers/macros.rs index 009280a..b5206e2 100644 --- a/src/gateway/providers/macros.rs +++ b/src/gateway/providers/macros.rs @@ -147,6 +147,7 @@ pub(crate) use provider; #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::borrow::Cow; use crate::gateway::{ diff --git a/src/gateway/providers/mistral.rs b/src/gateway/providers/mistral.rs index 3f472d6..a7d3045 100644 --- a/src/gateway/providers/mistral.rs +++ b/src/gateway/providers/mistral.rs @@ -25,6 +25,8 @@ provider!(Mistral { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; + use super::Mistral; use crate::gateway::traits::{ChatTransform, ProviderMeta}; @@ -32,14 +34,14 @@ mod tests { fn provider_macro_expands_correctly() { let provider = Mistral; - pretty_assertions::assert_eq!(provider.name(), "mistral"); - pretty_assertions::assert_eq!(provider.default_base_url(), "https://api.mistral.ai"); + assert_eq!(provider.name(), "mistral"); + assert_eq!(provider.default_base_url(), "https://api.mistral.ai"); - pretty_assertions::assert_eq!( + assert_eq!( provider.build_url(provider.default_base_url(), "ignored"), "https://api.mistral.ai/v1/chat/completions" ); - pretty_assertions::assert_eq!(provider.default_quirks().tool_args_may_be_object, true); + assert!(provider.default_quirks().tool_args_may_be_object); } } diff --git a/src/gateway/providers/mod.rs b/src/gateway/providers/mod.rs index f5a0d96..b093b8e 100644 --- a/src/gateway/providers/mod.rs +++ b/src/gateway/providers/mod.rs @@ -60,6 +60,7 @@ pub fn default_provider_registry() -> Result { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::default_provider_registry; #[test] diff --git a/src/gateway/providers/openai.rs b/src/gateway/providers/openai.rs index f6f2320..e5952e5 100644 --- a/src/gateway/providers/openai.rs +++ b/src/gateway/providers/openai.rs @@ -56,6 +56,7 @@ impl ProviderCapabilities for OpenAIDef { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::OpenAIDef; diff --git a/src/gateway/providers/openrouter.rs b/src/gateway/providers/openrouter.rs index a6d9358..cab920b 100644 --- a/src/gateway/providers/openrouter.rs +++ b/src/gateway/providers/openrouter.rs @@ -23,6 +23,7 @@ provider!(OpenRouter { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::OpenRouter; use crate::gateway::traits::ProviderMeta; @@ -30,11 +31,11 @@ mod tests { fn provider_macro_expands_correctly() { let provider = OpenRouter; - pretty_assertions::assert_eq!(provider.name(), "openrouter"); - pretty_assertions::assert_eq!(provider.default_base_url(), "https://openrouter.ai/api/v1"); - pretty_assertions::assert_eq!(provider.chat_endpoint_path("ignored"), "/chat/completions"); + assert_eq!(provider.name(), "openrouter"); + assert_eq!(provider.default_base_url(), "https://openrouter.ai/api/v1"); + assert_eq!(provider.chat_endpoint_path("ignored"), "/chat/completions"); - pretty_assertions::assert_eq!( + assert_eq!( provider.build_url(provider.default_base_url(), "ignored"), "https://openrouter.ai/api/v1/chat/completions" ); diff --git a/src/gateway/streams/bridged.rs b/src/gateway/streams/bridged.rs index 1d45a2d..b0af9b7 100644 --- a/src/gateway/streams/bridged.rs +++ b/src/gateway/streams/bridged.rs @@ -138,6 +138,7 @@ impl PinnedDrop for BridgedStream { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; diff --git a/src/gateway/streams/hub.rs b/src/gateway/streams/hub.rs index d500740..23234a8 100644 --- a/src/gateway/streams/hub.rs +++ b/src/gateway/streams/hub.rs @@ -107,6 +107,7 @@ impl Stream for HubChunkStream { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; diff --git a/src/gateway/streams/native.rs b/src/gateway/streams/native.rs index 813494d..dab2de4 100644 --- a/src/gateway/streams/native.rs +++ b/src/gateway/streams/native.rs @@ -119,6 +119,7 @@ impl PinnedDrop for NativeStream { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; diff --git a/src/gateway/streams/reader/aws_event_stream.rs b/src/gateway/streams/reader/aws_event_stream.rs index e23fb01..1c50c5c 100644 --- a/src/gateway/streams/reader/aws_event_stream.rs +++ b/src/gateway/streams/reader/aws_event_stream.rs @@ -180,6 +180,7 @@ fn build_aws_event_stream_exception_message(exception_type: &str, payload: &[u8] #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use aws_smithy_eventstream::frame::write_message_to; use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::Bytes; diff --git a/src/gateway/streams/reader/sse.rs b/src/gateway/streams/reader/sse.rs index ce730b5..2d5e1e3 100644 --- a/src/gateway/streams/reader/sse.rs +++ b/src/gateway/streams/reader/sse.rs @@ -71,6 +71,7 @@ where #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use bytes::Bytes; use futures::StreamExt; diff --git a/src/gateway/traits/chat_format.rs b/src/gateway/traits/chat_format.rs index 4e17c03..d69d43e 100644 --- a/src/gateway/traits/chat_format.rs +++ b/src/gateway/traits/chat_format.rs @@ -145,6 +145,7 @@ pub struct ChatStreamState { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::borrow::Cow; use http::HeaderMap; diff --git a/src/gateway/traits/provider.rs b/src/gateway/traits/provider.rs index 96e56a3..bbcc738 100644 --- a/src/gateway/traits/provider.rs +++ b/src/gateway/traits/provider.rs @@ -319,6 +319,7 @@ pub trait ImageGenTransform: Send + Sync + 'static {} #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::borrow::Cow; use http::HeaderMap; @@ -483,8 +484,8 @@ mod tests { let body = provider.transform_embeddings_request(&request).unwrap(); match body { EmbedRequestBody::Json(value) => { - pretty_assertions::assert_eq!(value["model"], "text-embedding-3-large"); - pretty_assertions::assert_eq!(value["input"], "hello"); + assert_eq!(value["model"], "text-embedding-3-large"); + assert_eq!(value["input"], "hello"); } } @@ -501,14 +502,14 @@ mod tests { }))) .unwrap(); - pretty_assertions::assert_eq!(response.data.len(), 1); + assert_eq!(response.data.len(), 1); let usage = match response.usage { Some(usage) => usage, None => panic!("expected usage in embedding response"), }; - pretty_assertions::assert_eq!(usage.total_tokens, 2); + assert_eq!(usage.total_tokens, 2); } #[test] diff --git a/src/gateway/types/anthropic.rs b/src/gateway/types/anthropic.rs index f358d42..3044795 100644 --- a/src/gateway/types/anthropic.rs +++ b/src/gateway/types/anthropic.rs @@ -325,6 +325,7 @@ pub struct AnthropicError { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::*; diff --git a/src/gateway/types/common.rs b/src/gateway/types/common.rs index 6ba6a3b..7bb5865 100644 --- a/src/gateway/types/common.rs +++ b/src/gateway/types/common.rs @@ -158,6 +158,7 @@ pub struct OpenAIResponsesExtras { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; #[test] diff --git a/src/gateway/types/embed.rs b/src/gateway/types/embed.rs index d0fe611..a75ae19 100644 --- a/src/gateway/types/embed.rs +++ b/src/gateway/types/embed.rs @@ -60,6 +60,7 @@ pub struct EmbeddingResponse { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{EmbeddingRequest, EmbeddingResponse, OneOrMany}; diff --git a/src/gateway/types/openai/mod.rs b/src/gateway/types/openai/mod.rs index 1413500..048001b 100644 --- a/src/gateway/types/openai/mod.rs +++ b/src/gateway/types/openai/mod.rs @@ -321,6 +321,7 @@ pub struct ChunkFunctionCall { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::*; diff --git a/src/gateway/types/openai/responses.rs b/src/gateway/types/openai/responses.rs index f6a3093..960beb3 100644 --- a/src/gateway/types/openai/responses.rs +++ b/src/gateway/types/openai/responses.rs @@ -406,6 +406,7 @@ impl ResponsesApiStreamEvent { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::*; diff --git a/src/proxy/handlers/chat_completions/span_attributes/tests.rs b/src/proxy/handlers/chat_completions/span_attributes/tests.rs index f3270c5..b6268fc 100644 --- a/src/proxy/handlers/chat_completions/span_attributes/tests.rs +++ b/src/proxy/handlers/chat_completions/span_attributes/tests.rs @@ -1,3 +1,4 @@ +use pretty_assertions::assert_eq; use serde_json::{Value, json}; use super::{StreamOutputCollector, request_span_properties, response_span_properties}; diff --git a/src/proxy/handlers/messages/types.rs b/src/proxy/handlers/messages/types.rs index e02b155..2602220 100644 --- a/src/proxy/handlers/messages/types.rs +++ b/src/proxy/handlers/messages/types.rs @@ -181,6 +181,7 @@ fn gateway_error_type(error: &GatewayError) -> &'static str { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use http::StatusCode; use serde_json::json; diff --git a/src/proxy/hooks/rate_limit/concurrent/local.rs b/src/proxy/hooks/rate_limit/concurrent/local.rs index 69c2f46..e795eed 100644 --- a/src/proxy/hooks/rate_limit/concurrent/local.rs +++ b/src/proxy/hooks/rate_limit/concurrent/local.rs @@ -111,6 +111,7 @@ impl ConcurrencyLimiter for LocalConcurrencyLimiter { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; #[tokio::test] diff --git a/src/proxy/hooks/rate_limit/concurrent/mod.rs b/src/proxy/hooks/rate_limit/concurrent/mod.rs index 98d9ec1..43fe736 100644 --- a/src/proxy/hooks/rate_limit/concurrent/mod.rs +++ b/src/proxy/hooks/rate_limit/concurrent/mod.rs @@ -82,6 +82,7 @@ pub fn get_concurrency_limiter() -> Arc { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; #[test] diff --git a/src/proxy/hooks/rate_limit/concurrent/utils.rs b/src/proxy/hooks/rate_limit/concurrent/utils.rs index 620eed0..08dbebd 100644 --- a/src/proxy/hooks/rate_limit/concurrent/utils.rs +++ b/src/proxy/hooks/rate_limit/concurrent/utils.rs @@ -114,6 +114,7 @@ impl IntoResponse for ConcurrencyLimitResponse { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; #[test] diff --git a/src/proxy/hooks/rate_limit/ratelimit/local.rs b/src/proxy/hooks/rate_limit/ratelimit/local.rs index 12a11f0..5184f38 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/local.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/local.rs @@ -63,6 +63,7 @@ impl RateLimiter for LocalRateLimiter { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use http::HeaderMap; use super::*; diff --git a/src/proxy/hooks/rate_limit/ratelimit/mod.rs b/src/proxy/hooks/rate_limit/ratelimit/mod.rs index 559e5a8..df8d7f5 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/mod.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/mod.rs @@ -99,6 +99,7 @@ pub fn get_rate_limiter() -> Arc { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; use crate::config::entities::types::RateLimit; diff --git a/src/proxy/hooks/rate_limit/ratelimit/utils.rs b/src/proxy/hooks/rate_limit/ratelimit/utils.rs index 14113d3..5446a10 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/utils.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/utils.rs @@ -255,6 +255,7 @@ impl IntoResponse for RateLimitResponse { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::time::Duration; use super::*; diff --git a/src/proxy/provider.rs b/src/proxy/provider.rs index 7a10b41..77a82bd 100644 --- a/src/proxy/provider.rs +++ b/src/proxy/provider.rs @@ -145,6 +145,7 @@ fn default_bedrock_base_url(region: &str) -> String { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use reqwest::Url; use super::provider_auth_and_base_url; diff --git a/src/utils/instance.rs b/src/utils/instance.rs index b972ceb..afddd77 100644 --- a/src/utils/instance.rs +++ b/src/utils/instance.rs @@ -97,6 +97,7 @@ fn write_id_file(path: &Path, id: &str) -> Result<()> { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::path::PathBuf; use super::*; From 7973e445d7a63ae75a0c852543bf0fef65c62a5c Mon Sep 17 00:00:00 2001 From: bzp2010 Date: Sun, 3 May 2026 13:57:44 +0800 Subject: [PATCH 2/2] fix comments --- src/config/entities/providers.rs | 6 +- src/gateway/formats/anthropic_messages.rs | 57 ++++++------- src/gateway/formats/openai/mod.rs | 5 +- src/gateway/gateway.rs | 83 ++++++++++--------- src/gateway/provider_instance.rs | 21 ++--- src/gateway/providers/anthropic/mod.rs | 3 +- src/gateway/providers/anthropic/transform.rs | 19 +++-- src/gateway/providers/bedrock/transform.rs | 5 +- src/gateway/providers/macros.rs | 5 +- src/gateway/providers/openai.rs | 5 +- .../streams/reader/aws_event_stream.rs | 5 +- src/gateway/streams/reader/sse.rs | 5 +- src/gateway/traits/chat_format.rs | 9 +- src/gateway/types/anthropic.rs | 5 +- src/gateway/types/embed.rs | 5 +- src/gateway/types/openai/mod.rs | 11 +-- src/gateway/types/openai/responses.rs | 13 +-- .../hooks/rate_limit/concurrent/local.rs | 10 ++- src/proxy/hooks/rate_limit/ratelimit/local.rs | 9 +- src/proxy/provider.rs | 5 +- 20 files changed, 152 insertions(+), 134 deletions(-) diff --git a/src/config/entities/providers.rs b/src/config/entities/providers.rs index a2cd924..067af22 100644 --- a/src/config/entities/providers.rs +++ b/src/config/entities/providers.rs @@ -112,6 +112,7 @@ impl ProvidersStore { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -213,9 +214,6 @@ mod tests { assert_eq!(provider.name, "openai-primary"); assert_eq!(provider.provider_type(), "openai"); - assert!(matches!( - provider.provider, - super::ProviderConfig::OpenAI(_) - )); + assert_matches!(provider.provider, super::ProviderConfig::OpenAI(_)); } } diff --git a/src/gateway/formats/anthropic_messages.rs b/src/gateway/formats/anthropic_messages.rs index c376eb3..8ad1520 100644 --- a/src/gateway/formats/anthropic_messages.rs +++ b/src/gateway/formats/anthropic_messages.rs @@ -1000,6 +1000,7 @@ fn openai_finish_reason_to_anthropic_stream(finish_reason: &str) -> String { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -1149,24 +1150,24 @@ mod tests { &BridgeContext::default(), ) .unwrap(); - assert!(matches!( + assert_matches!( &events[0], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } if message.id == "chatcmpl-123" && message.usage.input_tokens.is_none() && message.usage.output_tokens.is_none() - )); - assert!(matches!( + ); + assert_matches!( &events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. } if *index == 0 - )); - assert!(matches!( + ); + assert_matches!( &events[2], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if *index == 0 && matches!(delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") - )); + ); assert!( AnthropicMessagesFormat::from_hub_stream( @@ -1180,12 +1181,12 @@ mod tests { let end_events = AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default()); - assert!(matches!( + assert_matches!( &end_events[0], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } if *index == 0 - )); - assert!(matches!( + ); + assert_matches!( &end_events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, usage } if delta.stop_reason.is_none() @@ -1193,11 +1194,11 @@ mod tests { && usage.output_tokens == Some(9) && usage.cache_creation_input_tokens == Some(0) && usage.cache_read_input_tokens == Some(2) - )); - assert!(matches!( + ); + assert_matches!( &end_events[2], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop - )); + ); } #[test] @@ -1243,16 +1244,16 @@ mod tests { assert_eq!(bridged.usage.output_tokens, 7); assert_eq!(bridged.usage.cache_read_input_tokens, 2); assert!(bridged.usage.cache_creation.is_none()); - assert!(matches!( + assert_matches!( &bridged.content[0], crate::gateway::types::anthropic::AnthropicContentBlock::Text { text, .. } if text == "Calling tool" - )); - assert!(matches!( + ); + assert_matches!( &bridged.content[1], crate::gateway::types::anthropic::AnthropicContentBlock::ToolUse { name, .. } if name == "get_weather" - )); + ); } #[test] @@ -1309,18 +1310,18 @@ mod tests { &BridgeContext::default(), ) .unwrap(); - assert!(matches!( + assert_matches!( &first_events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, content_block } if *index == 0 && matches!(content_block, crate::gateway::types::anthropic::AnthropicContentBlock::ToolUse { name, .. } if name == "get_weather") - )); - assert!(matches!( + ); + assert_matches!( &first_events[2], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if *index == 0 && matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == "{\"city\"") - )); + ); let second_events = AnthropicMessagesFormat::from_hub_stream( &second_chunk, @@ -1328,20 +1329,20 @@ mod tests { &BridgeContext::default(), ) .unwrap(); - assert!(matches!( + assert_matches!( &second_events[0], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if *index == 0 && matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == ":\"SF\"}") - )); + ); let end_events = AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default()); - assert!(matches!( + assert_matches!( &end_events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, .. } if delta.stop_reason.as_deref() == Some("tool_use") - )); + ); } #[test] @@ -1375,11 +1376,11 @@ mod tests { &mut missing_type_state, &BridgeContext::default(), ); - assert!(matches!( + assert_matches!( missing_type_result, Err(GatewayError::Bridge(message)) if message.contains("requires tool call types") - )); + ); let mut invalid_type_state = super::AnthropicBridgeState::default(); let invalid_type_chunk: crate::gateway::types::openai::ChatCompletionChunk = @@ -1411,11 +1412,11 @@ mod tests { &mut invalid_type_state, &BridgeContext::default(), ); - assert!(matches!( + assert_matches!( invalid_type_result, Err(GatewayError::Bridge(message)) if message.contains("only supports function tool calls") - )); + ); } #[test] diff --git a/src/gateway/formats/openai/mod.rs b/src/gateway/formats/openai/mod.rs index aaa63e0..52b9bd8 100644 --- a/src/gateway/formats/openai/mod.rs +++ b/src/gateway/formats/openai/mod.rs @@ -68,6 +68,7 @@ impl ChatFormat for OpenAIChatFormat { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -225,10 +226,10 @@ mod tests { let error = OpenAIChatFormat::transform_native_stream_chunk(&provider, "data: {}", &mut ()) .unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::NativeNotSupported { provider } if provider == "dummy" - )); + ); assert!(OpenAIChatFormat::native_support(&provider).is_none()); } } diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index a465944..7c176ef 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -402,7 +402,6 @@ async fn provider_error(response: reqwest::Response, provider: &str) -> GatewayE #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::{ borrow::Cow, sync::{ @@ -411,6 +410,7 @@ mod tests { }, }; + use assert_matches::assert_matches; use aws_smithy_eventstream::frame::write_message_to; use aws_smithy_types::event_stream::{Header, HeaderValue as EventStreamHeaderValue, Message}; use axum::{Json, Router, extract::OriginalUri, routing::post}; @@ -420,6 +420,7 @@ mod tests { HeaderMap, HeaderValue, StatusCode, header::{AUTHORIZATION, CONTENT_TYPE, HeaderName}, }; + use pretty_assertions::assert_eq; use reqwest::Url; use serde_json::{Value, json}; use tokio::{net::TcpListener, sync::Mutex, task::JoinHandle}; @@ -908,11 +909,11 @@ mod tests { }; assert_eq!(response.model, "gpt-test"); - assert!(matches!( + assert_matches!( response.choices[0].message.content.as_ref(), Some(crate::gateway::types::openai::MessageContent::Text(text)) if text == "hello from hub" - )); + ); assert_eq!(usage.input_tokens, Some(7)); assert_eq!(usage.output_tokens, Some(9)); assert_eq!(usage.total_tokens, Some(16)); @@ -1065,11 +1066,11 @@ mod tests { ); assert!(response.id.starts_with("bedrock-")); assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop")); - assert!(matches!( + assert_matches!( response.choices[0].message.content.as_ref(), Some(crate::gateway::types::openai::MessageContent::Text(text)) if text == "hello from bedrock" - )); + ); assert_eq!(usage.total_tokens, Some(16)); let observed = observed.lock().await.take().unwrap(); @@ -1166,11 +1167,11 @@ mod tests { let error = gateway.embed(&request, &instance).await.unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::EmbeddingsNotSupported { provider } if provider == "native-test" - )); + ); } #[tokio::test] @@ -1297,10 +1298,10 @@ mod tests { assert_eq!(response.usage.input_tokens, 5); assert_eq!(response.usage.output_tokens, 9); assert_eq!(response.usage.cache_read_input_tokens, 2); - assert!(matches!( + assert_matches!( &response.content[0], AnthropicContentBlock::Text { text, .. } if text == "hello from hub" - )); + ); assert_eq!(usage.input_tokens, Some(7)); assert_eq!(usage.output_tokens, Some(9)); assert_eq!(usage.total_tokens, Some(16)); @@ -1396,39 +1397,39 @@ mod tests { let message_stop = stream.next().await.unwrap().unwrap(); assert!(stream.next().await.is_none()); - assert!(matches!( + assert_matches!( message_start, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } if message.id == "chatcmpl-789" - )); - assert!(matches!( + ); + assert_matches!( block_start, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( block_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if index == 0 && matches!(&delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") - )); - assert!(matches!( + ); + assert_matches!( block_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( message_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { usage, .. } if usage.input_tokens == Some(5) && usage.output_tokens == Some(9) && usage.cache_creation_input_tokens == Some(0) && usage.cache_read_input_tokens == Some(2) - )); - assert!(matches!( + ); + assert_matches!( message_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop - )); + ); let usage = usage_rx.await.unwrap(); assert_eq!(usage.input_tokens, Some(7)); @@ -1496,10 +1497,10 @@ mod tests { assert_eq!(response.id, "msg_123"); assert_eq!(response.model, "claude-3-5-sonnet-20241022"); - assert!(matches!( + assert_matches!( &response.content[0], AnthropicContentBlock::Text { text, .. } if text == "hello from native" - )); + ); assert_eq!(usage.input_tokens, Some(10)); assert_eq!(usage.output_tokens, Some(4)); assert_eq!(usage.total_tokens, Some(14)); @@ -1574,42 +1575,42 @@ mod tests { let message_stop = stream.next().await.unwrap().unwrap(); assert!(stream.next().await.is_none()); - assert!(matches!( + assert_matches!( message_start, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } if message.usage.input_tokens == Some(3) && message.usage.output_tokens == Some(1) && message.usage.cache_creation_input_tokens == Some(5) && message.usage.cache_read_input_tokens == Some(2) - )); - assert!(matches!( + ); + assert_matches!( block_start, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( block_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if index == 0 && matches!(&delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") - )); - assert!(matches!( + ); + assert_matches!( block_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( message_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { usage, .. } if usage.input_tokens == Some(3) && usage.output_tokens == Some(4) && usage.cache_creation_input_tokens.is_none() && usage.cache_read_input_tokens.is_none() - )); - assert!(matches!( + ); + assert_matches!( message_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop - )); + ); let usage = usage_rx.await.unwrap(); assert_eq!(usage.input_tokens, Some(10)); @@ -1939,11 +1940,11 @@ mod tests { .unwrap(); let result = gateway.chat_completion(&request, &instance).await; - assert!(matches!( - result, + assert_matches!( + result.map(|_| ()), Err(GatewayError::Validation(message)) if message.contains("JsonArrayStream") - )); + ); assert_eq!(request_count.load(Ordering::SeqCst), 0); server.abort(); @@ -1984,11 +1985,11 @@ mod tests { let result = gateway .chat::(&request, &instance) .await; - assert!(matches!( - result, + assert_matches!( + result.map(|_| ()), Err(GatewayError::Validation(message)) if message.contains("JsonArrayStream") - )); + ); assert_eq!(request_count.load(Ordering::SeqCst), 0); server.abort(); diff --git a/src/gateway/provider_instance.rs b/src/gateway/provider_instance.rs index a730299..bda5d5c 100644 --- a/src/gateway/provider_instance.rs +++ b/src/gateway/provider_instance.rs @@ -159,6 +159,7 @@ impl ProviderRegistryBuilder { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use std::{borrow::Cow, sync::Arc}; @@ -284,23 +285,23 @@ mod tests { ); let error = ProviderAuth::None.api_key().unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("ProviderAuth::ApiKey") - )); + ); } #[test] fn provider_auth_api_key_for_adds_provider_context() { let error = ProviderAuth::None.api_key_for("deepseek").unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("deepseek") && message.contains("ProviderAuth::ApiKey") - )); + ); } #[test] @@ -318,12 +319,12 @@ mod tests { let error = ProviderAuth::None .aws_static_credentials_for("bedrock") .unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("bedrock") && message.contains("ProviderAuth::AwsStatic") - )); + ); } #[test] @@ -352,12 +353,12 @@ mod tests { let error = instance.effective_base_url().unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("invalid-url") && message.contains("default_base_url") - )); + ); } #[test] @@ -401,11 +402,11 @@ mod tests { .err() .unwrap(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("dummy") && message.contains("already registered") - )); + ); } } diff --git a/src/gateway/providers/anthropic/mod.rs b/src/gateway/providers/anthropic/mod.rs index 4a0402c..b5e78da 100644 --- a/src/gateway/providers/anthropic/mod.rs +++ b/src/gateway/providers/anthropic/mod.rs @@ -120,6 +120,7 @@ impl NativeAnthropicMessagesSupport for AnthropicDef { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -182,7 +183,7 @@ mod tests { assert_eq!(body["model"], "claude-3-5-sonnet-20241022"); assert_eq!(parsed.id, "msg_123"); - assert!(matches!(events.as_slice(), [AnthropicStreamEvent::Ping])); + assert_matches!(events.as_slice(), [AnthropicStreamEvent::Ping]); } #[test] diff --git a/src/gateway/providers/anthropic/transform.rs b/src/gateway/providers/anthropic/transform.rs index da9163f..d5c3a1b 100644 --- a/src/gateway/providers/anthropic/transform.rs +++ b/src/gateway/providers/anthropic/transform.rs @@ -743,6 +743,7 @@ fn now_unix_secs() -> u64 { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -790,10 +791,10 @@ mod tests { assert_eq!(anthropic.messages.len(), 1); assert_eq!(anthropic.messages[0].role, "user"); assert_eq!(anthropic.tools.as_ref().unwrap()[0].name, "get_weather"); - assert!(matches!( + assert_matches!( anthropic.tool_choice, Some(crate::gateway::types::anthropic::AnthropicToolChoice::Auto) - )); + ); assert_eq!( anthropic.metadata.as_ref().unwrap().user_id.as_deref(), Some("user-123") @@ -830,10 +831,10 @@ mod tests { openai.choices[0].finish_reason.as_deref(), Some("tool_calls") ); - assert!(matches!( + assert_matches!( openai.choices[0].message.content.as_ref(), Some(crate::gateway::types::openai::MessageContent::Text(text)) if text == "Let me check." - )); + ); let tool_call = &openai.choices[0].message.tool_calls.as_ref().unwrap()[0]; assert_eq!(tool_call.id, "tu_1"); assert_eq!(tool_call.function.name, "get_weather"); @@ -926,7 +927,7 @@ mod tests { ); let events = parse_anthropic_native_sse(r#"data: {"type":"ping"}"#).unwrap(); - assert!(matches!(events.as_slice(), [AnthropicStreamEvent::Ping])); + assert_matches!(events.as_slice(), [AnthropicStreamEvent::Ping]); } #[test] @@ -939,12 +940,12 @@ mod tests { .unwrap(); let error = openai_to_anthropic_request(&request).unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Bridge(message) if message.contains("n=1") && message.contains('2') - )); + ); } #[test] @@ -957,10 +958,10 @@ mod tests { .unwrap(); let error = openai_to_anthropic_request(&request).unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Bridge(message) if message.contains("requires tools") - )); + ); } } diff --git a/src/gateway/providers/bedrock/transform.rs b/src/gateway/providers/bedrock/transform.rs index 4f832ee..682149e 100644 --- a/src/gateway/providers/bedrock/transform.rs +++ b/src/gateway/providers/bedrock/transform.rs @@ -601,6 +601,7 @@ fn current_unix_timestamp() -> Result { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -699,10 +700,10 @@ mod tests { mapped.choices[0].finish_reason.as_deref(), Some("tool_calls") ); - assert!(matches!( + assert_matches!( mapped.choices[0].message.content.as_ref(), Some(MessageContent::Text(text)) if text == "Need a tool." - )); + ); assert_eq!( mapped.choices[0].message.tool_calls.as_ref().unwrap()[0] .function diff --git a/src/gateway/providers/macros.rs b/src/gateway/providers/macros.rs index b5206e2..e6cf80c 100644 --- a/src/gateway/providers/macros.rs +++ b/src/gateway/providers/macros.rs @@ -147,6 +147,7 @@ pub(crate) use provider; #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use std::borrow::Cow; @@ -203,11 +204,11 @@ mod tests { .build_auth_headers(&ProviderAuth::None) .unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Validation(message) if message.contains("macro-test") && message.contains("ProviderAuth::ApiKey") - )); + ); } } diff --git a/src/gateway/providers/openai.rs b/src/gateway/providers/openai.rs index e5952e5..ec5b5d5 100644 --- a/src/gateway/providers/openai.rs +++ b/src/gateway/providers/openai.rs @@ -56,6 +56,7 @@ impl ProviderCapabilities for OpenAIDef { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -103,12 +104,12 @@ mod tests { .build_auth_headers(&ProviderAuth::None) .unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Validation(message) if message.contains("openai") && message.contains("ProviderAuth::ApiKey") - )); + ); } #[test] diff --git a/src/gateway/streams/reader/aws_event_stream.rs b/src/gateway/streams/reader/aws_event_stream.rs index 1c50c5c..16df7fe 100644 --- a/src/gateway/streams/reader/aws_event_stream.rs +++ b/src/gateway/streams/reader/aws_event_stream.rs @@ -180,6 +180,7 @@ fn build_aws_event_stream_exception_message(exception_type: &str, payload: &[u8] #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use aws_smithy_eventstream::frame::write_message_to; use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; @@ -229,11 +230,11 @@ mod tests { let mut reader = aws_event_stream_reader(byte_stream); - assert!(matches!( + assert_matches!( reader.next().await.unwrap(), Err(GatewayError::Stream(message)) if message.contains("validationException") && message.contains("bad request") - )); + ); assert!(reader.next().await.is_none()); } diff --git a/src/gateway/streams/reader/sse.rs b/src/gateway/streams/reader/sse.rs index 2d5e1e3..3cd83ba 100644 --- a/src/gateway/streams/reader/sse.rs +++ b/src/gateway/streams/reader/sse.rs @@ -71,6 +71,7 @@ where #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use bytes::Bytes; use futures::StreamExt; @@ -117,10 +118,10 @@ mod tests { let mut reader = sse_reader(byte_stream); - assert!(matches!( + assert_matches!( reader.next().await.unwrap(), Err(GatewayError::Http(_)) - )); + ); assert!(reader.next().await.is_none()); } } diff --git a/src/gateway/traits/chat_format.rs b/src/gateway/traits/chat_format.rs index d69d43e..5b181a6 100644 --- a/src/gateway/traits/chat_format.rs +++ b/src/gateway/traits/chat_format.rs @@ -145,6 +145,7 @@ pub struct ChatStreamState { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use std::borrow::Cow; @@ -283,10 +284,10 @@ mod tests { let native = NativeHandler::OpenAIResponses(&provider); let error = DummyFormat::call_native(&native, &json!({}), false).unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::NativeNotSupported { provider } if provider == "dummy-native-provider" - )); + ); } #[test] @@ -295,12 +296,12 @@ mod tests { let native = NativeHandler::OpenAIResponses(&provider); let error = DummyFormat::parse_native_response(&native, json!({})).unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Bridge(message) if message.contains("parse_native_response called on a non-native format") && message.contains("dummy-native-provider") - )); + ); } #[test] diff --git a/src/gateway/types/anthropic.rs b/src/gateway/types/anthropic.rs index 3044795..4b827a3 100644 --- a/src/gateway/types/anthropic.rs +++ b/src/gateway/types/anthropic.rs @@ -325,6 +325,7 @@ pub struct AnthropicError { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -358,7 +359,7 @@ mod tests { "messages": [{"role": "user", "content": "Hi"}] }); let req: AnthropicMessagesRequest = serde_json::from_value(json).unwrap(); - assert!(matches!(req.system, Some(SystemPrompt::Text(ref s)) if s == "You are helpful.")); + assert_matches!(req.system, Some(SystemPrompt::Text(ref s)) if s == "You are helpful."); } #[test] @@ -412,7 +413,7 @@ mod tests { }); let req: AnthropicMessagesRequest = serde_json::from_value(json).unwrap(); assert_eq!(req.tools.as_ref().unwrap().len(), 1); - assert!(matches!(req.tool_choice, Some(AnthropicToolChoice::Auto))); + assert_matches!(req.tool_choice, Some(AnthropicToolChoice::Auto)); } #[test] diff --git a/src/gateway/types/embed.rs b/src/gateway/types/embed.rs index a75ae19..e3aae19 100644 --- a/src/gateway/types/embed.rs +++ b/src/gateway/types/embed.rs @@ -60,6 +60,7 @@ pub struct EmbeddingResponse { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -77,7 +78,7 @@ mod tests { .unwrap(); assert_eq!(request.model, "text-embedding-3-large"); - assert!(matches!(request.input, OneOrMany::Many(_))); + assert_matches!(request.input, OneOrMany::Many(_)); let value = serde_json::to_value(&request).unwrap(); assert_eq!(value["dimensions"], 256); @@ -92,7 +93,7 @@ mod tests { })) .unwrap(); - assert!(matches!(request.input, OneOrMany::One(_))); + assert_matches!(request.input, OneOrMany::One(_)); let value = serde_json::to_value(&request).unwrap(); assert_eq!(value["input"], "hello"); diff --git a/src/gateway/types/openai/mod.rs b/src/gateway/types/openai/mod.rs index 048001b..0fd5e6e 100644 --- a/src/gateway/types/openai/mod.rs +++ b/src/gateway/types/openai/mod.rs @@ -321,6 +321,7 @@ pub struct ChunkFunctionCall { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -360,7 +361,7 @@ mod tests { let req: ChatCompletionRequest = serde_json::from_value(json).unwrap(); assert!(req.tools.is_some()); assert_eq!(req.tools.as_ref().unwrap().len(), 1); - assert!(matches!(req.tool_choice, Some(ToolChoice::Mode(ref s)) if s == "auto")); + assert_matches!(req.tool_choice, Some(ToolChoice::Mode(ref s)) if s == "auto"); } #[test] @@ -438,7 +439,7 @@ mod tests { fn message_content_string_or_parts() { let json = json!({"role": "user", "content": "Hello"}); let msg: ChatMessage = serde_json::from_value(json).unwrap(); - assert!(matches!(msg.content, Some(MessageContent::Text(ref s)) if s == "Hello")); + assert_matches!(msg.content, Some(MessageContent::Text(ref s)) if s == "Hello"); let json = json!({ "role": "user", @@ -448,15 +449,15 @@ mod tests { ] }); let msg: ChatMessage = serde_json::from_value(json).unwrap(); - assert!(matches!(msg.content, Some(MessageContent::Parts(ref p)) if p.len() == 2)); + assert_matches!(msg.content, Some(MessageContent::Parts(ref p)) if p.len() == 2); } #[test] fn stop_condition_single_or_multiple() { let single: StopCondition = serde_json::from_value(json!("stop")).unwrap(); - assert!(matches!(single, StopCondition::Single(ref s) if s == "stop")); + assert_matches!(single, StopCondition::Single(ref s) if s == "stop"); let multiple: StopCondition = serde_json::from_value(json!(["stop", "end"])).unwrap(); - assert!(matches!(multiple, StopCondition::Multiple(ref v) if v.len() == 2)); + assert_matches!(multiple, StopCondition::Multiple(ref v) if v.len() == 2); } } diff --git a/src/gateway/types/openai/responses.rs b/src/gateway/types/openai/responses.rs index 960beb3..5194ebf 100644 --- a/src/gateway/types/openai/responses.rs +++ b/src/gateway/types/openai/responses.rs @@ -406,6 +406,7 @@ impl ResponsesApiStreamEvent { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use serde_json::json; @@ -419,7 +420,7 @@ mod tests { }); let req: ResponsesApiRequest = serde_json::from_value(json).unwrap(); assert_eq!(req.model, "gpt-4.1"); - assert!(matches!(req.input, ResponsesInput::Text(ref s) if s == "Hello")); + assert_matches!(req.input, ResponsesInput::Text(ref s) if s == "Hello"); } #[test] @@ -478,11 +479,11 @@ mod tests { assert!( matches!(req.conversation, Some(ConversationReference::Descriptor { ref id }) if id == "conv_123") ); - assert!(matches!( + assert_matches!( req.prompt_cache_retention, Some(PromptCacheRetention::InMemory) - )); - assert!(matches!(req.truncation, Some(Truncation::Auto))); + ); + assert_matches!(req.truncation, Some(Truncation::Auto)); assert_eq!(req.top_logprobs, Some(5)); let serialized = serde_json::to_value(&req).unwrap(); @@ -528,8 +529,8 @@ mod tests { let req: ResponsesApiRequest = serde_json::from_value(json).unwrap(); let tools = req.tools.unwrap(); assert_eq!(tools.len(), 3); - assert!(matches!(&tools[0], ResponsesTool::Function { name, .. } if name == "get_weather")); - assert!(matches!(&tools[1], ResponsesTool::WebSearch { .. })); + assert_matches!(&tools[0], ResponsesTool::Function { name, .. } if name == "get_weather"); + assert_matches!(&tools[1], ResponsesTool::WebSearch { .. }); assert!( matches!(&tools[2], ResponsesTool::FileSearch { vector_store_ids, .. } if vector_store_ids == &["vs_1"]) ); diff --git a/src/proxy/hooks/rate_limit/concurrent/local.rs b/src/proxy/hooks/rate_limit/concurrent/local.rs index e795eed..94f0225 100644 --- a/src/proxy/hooks/rate_limit/concurrent/local.rs +++ b/src/proxy/hooks/rate_limit/concurrent/local.rs @@ -111,7 +111,9 @@ impl ConcurrencyLimiter for LocalConcurrencyLimiter { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; + use super::*; #[tokio::test] @@ -146,13 +148,13 @@ mod tests { let _p2 = limiter.try_acquire("k3", 2).await.unwrap(); let err = limiter.try_acquire("k3", 2).await; - assert!(matches!( - err, + assert_matches!( + err.map(|_| ()), Err(ConcurrencyError::Exceeded { limit: 2, current: 2, }) - )); + ); } #[tokio::test] @@ -187,7 +189,7 @@ mod tests { // "Hot reload" — decrease limit to 1; existing permits stay, new ones rejected let err = limiter.try_acquire("hrd", 1).await; - assert!(matches!(err, Err(ConcurrencyError::Exceeded { .. }))); + assert_matches!(err.map(|_| ()), Err(ConcurrencyError::Exceeded { .. })); // After existing permits drop, should work again with new limit drop(_p1); diff --git a/src/proxy/hooks/rate_limit/ratelimit/local.rs b/src/proxy/hooks/rate_limit/ratelimit/local.rs index 5184f38..9a4f3c3 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/local.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/local.rs @@ -63,6 +63,7 @@ impl RateLimiter for LocalRateLimiter { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use http::HeaderMap; @@ -313,8 +314,8 @@ mod tests { assert!(run_check(&e, CheckPhase::Pre).await.is_ok()); } let (m, err) = run_check(&e, CheckPhase::Pre).await.unwrap_err(); - assert!(matches!(m, RateLimitMetric::RPM)); - assert!(matches!(err, RateLimitError::Exceeded(_))); + assert_matches!(m, RateLimitMetric::RPM); + assert_matches!(err, RateLimitError::Exceeded(_)); } #[tokio::test] @@ -322,8 +323,8 @@ mod tests { let e = make_entity("exc_2", None, Some(100)); assert!(run_check(&e, CheckPhase::Post(90)).await.is_ok()); let (m, err) = run_check(&e, CheckPhase::Post(20)).await.unwrap_err(); - assert!(matches!(m, RateLimitMetric::TPM)); - assert!(matches!(err, RateLimitError::Exceeded(_))); + assert_matches!(m, RateLimitMetric::TPM); + assert_matches!(err, RateLimitError::Exceeded(_)); } #[tokio::test] diff --git a/src/proxy/provider.rs b/src/proxy/provider.rs index 77a82bd..56692da 100644 --- a/src/proxy/provider.rs +++ b/src/proxy/provider.rs @@ -145,6 +145,7 @@ fn default_bedrock_base_url(region: &str) -> String { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use pretty_assertions::assert_eq; use reqwest::Url; @@ -299,10 +300,10 @@ mod tests { let error = provider_auth_and_base_url(&config).unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Internal(message) if message.contains("unsupported scheme") - )); + ); } }