diff --git a/src/config/entities/apikeys.rs b/src/config/entities/apikeys.rs index a8d6e85..8a2ce4c 100644 --- a/src/config/entities/apikeys.rs +++ b/src/config/entities/apikeys.rs @@ -75,6 +75,7 @@ impl ApiKeysStore { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error}; diff --git a/src/config/entities/mod.rs b/src/config/entities/mod.rs index 7b26f3b..76754fe 100644 --- a/src/config/entities/mod.rs +++ b/src/config/entities/mod.rs @@ -448,6 +448,7 @@ impl EntityStore { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::{sync::Mutex, time::Duration}; use anyhow::Result; diff --git a/src/config/entities/models.rs b/src/config/entities/models.rs index 5e984ac..0f91d1b 100644 --- a/src/config/entities/models.rs +++ b/src/config/entities/models.rs @@ -93,6 +93,7 @@ impl ModelsStore { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error}; diff --git a/src/config/entities/providers.rs b/src/config/entities/providers.rs index 4f0a9e5..067af22 100644 --- a/src/config/entities/providers.rs +++ b/src/config/entities/providers.rs @@ -112,6 +112,8 @@ impl ProvidersStore { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error}; @@ -212,9 +214,6 @@ mod tests { assert_eq!(provider.name, "openai-primary"); assert_eq!(provider.provider_type(), "openai"); - assert!(matches!( - provider.provider, - super::ProviderConfig::OpenAI(_) - )); + assert_matches!(provider.provider, super::ProviderConfig::OpenAI(_)); } } diff --git a/src/gateway/error.rs b/src/gateway/error.rs index d584263..12d4f7b 100644 --- a/src/gateway/error.rs +++ b/src/gateway/error.rs @@ -88,6 +88,7 @@ pub type Result = std::result::Result; #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::*; diff --git a/src/gateway/formats/anthropic_messages.rs b/src/gateway/formats/anthropic_messages.rs index 48195ef..8ad1520 100644 --- a/src/gateway/formats/anthropic_messages.rs +++ b/src/gateway/formats/anthropic_messages.rs @@ -1000,6 +1000,8 @@ fn openai_finish_reason_to_anthropic_stream(finish_reason: &str) -> String { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::AnthropicMessagesFormat; @@ -1148,24 +1150,24 @@ mod tests { &BridgeContext::default(), ) .unwrap(); - assert!(matches!( + assert_matches!( &events[0], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } if message.id == "chatcmpl-123" && message.usage.input_tokens.is_none() && message.usage.output_tokens.is_none() - )); - assert!(matches!( + ); + assert_matches!( &events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. } if *index == 0 - )); - assert!(matches!( + ); + assert_matches!( &events[2], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if *index == 0 && matches!(delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") - )); + ); assert!( AnthropicMessagesFormat::from_hub_stream( @@ -1179,12 +1181,12 @@ mod tests { let end_events = AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default()); - assert!(matches!( + assert_matches!( &end_events[0], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } if *index == 0 - )); - assert!(matches!( + ); + assert_matches!( &end_events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, usage } if delta.stop_reason.is_none() @@ -1192,11 +1194,11 @@ mod tests { && usage.output_tokens == Some(9) && usage.cache_creation_input_tokens == Some(0) && usage.cache_read_input_tokens == Some(2) - )); - assert!(matches!( + ); + assert_matches!( &end_events[2], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop - )); + ); } #[test] @@ -1242,16 +1244,16 @@ mod tests { assert_eq!(bridged.usage.output_tokens, 7); assert_eq!(bridged.usage.cache_read_input_tokens, 2); assert!(bridged.usage.cache_creation.is_none()); - assert!(matches!( + assert_matches!( &bridged.content[0], crate::gateway::types::anthropic::AnthropicContentBlock::Text { text, .. } if text == "Calling tool" - )); - assert!(matches!( + ); + assert_matches!( &bridged.content[1], crate::gateway::types::anthropic::AnthropicContentBlock::ToolUse { name, .. } if name == "get_weather" - )); + ); } #[test] @@ -1308,18 +1310,18 @@ mod tests { &BridgeContext::default(), ) .unwrap(); - assert!(matches!( + assert_matches!( &first_events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, content_block } if *index == 0 && matches!(content_block, crate::gateway::types::anthropic::AnthropicContentBlock::ToolUse { name, .. } if name == "get_weather") - )); - assert!(matches!( + ); + assert_matches!( &first_events[2], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if *index == 0 && matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == "{\"city\"") - )); + ); let second_events = AnthropicMessagesFormat::from_hub_stream( &second_chunk, @@ -1327,20 +1329,20 @@ mod tests { &BridgeContext::default(), ) .unwrap(); - assert!(matches!( + assert_matches!( &second_events[0], crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if *index == 0 && matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == ":\"SF\"}") - )); + ); let end_events = AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default()); - assert!(matches!( + assert_matches!( &end_events[1], crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, .. } if delta.stop_reason.as_deref() == Some("tool_use") - )); + ); } #[test] @@ -1374,11 +1376,11 @@ mod tests { &mut missing_type_state, &BridgeContext::default(), ); - assert!(matches!( + assert_matches!( missing_type_result, Err(GatewayError::Bridge(message)) if message.contains("requires tool call types") - )); + ); let mut invalid_type_state = super::AnthropicBridgeState::default(); let invalid_type_chunk: crate::gateway::types::openai::ChatCompletionChunk = @@ -1410,11 +1412,11 @@ mod tests { &mut invalid_type_state, &BridgeContext::default(), ); - assert!(matches!( + assert_matches!( invalid_type_result, Err(GatewayError::Bridge(message)) if message.contains("only supports function tool calls") - )); + ); } #[test] diff --git a/src/gateway/formats/openai/mod.rs b/src/gateway/formats/openai/mod.rs index bde0581..52b9bd8 100644 --- a/src/gateway/formats/openai/mod.rs +++ b/src/gateway/formats/openai/mod.rs @@ -68,6 +68,8 @@ impl ChatFormat for OpenAIChatFormat { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::OpenAIChatFormat; @@ -224,10 +226,10 @@ mod tests { let error = OpenAIChatFormat::transform_native_stream_chunk(&provider, "data: {}", &mut ()) .unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::NativeNotSupported { provider } if provider == "dummy" - )); + ); assert!(OpenAIChatFormat::native_support(&provider).is_none()); } } diff --git a/src/gateway/gateway.rs b/src/gateway/gateway.rs index 3d1e924..7c176ef 100644 --- a/src/gateway/gateway.rs +++ b/src/gateway/gateway.rs @@ -410,6 +410,7 @@ mod tests { }, }; + use assert_matches::assert_matches; use aws_smithy_eventstream::frame::write_message_to; use aws_smithy_types::event_stream::{Header, HeaderValue as EventStreamHeaderValue, Message}; use axum::{Json, Router, extract::OriginalUri, routing::post}; @@ -419,6 +420,7 @@ mod tests { HeaderMap, HeaderValue, StatusCode, header::{AUTHORIZATION, CONTENT_TYPE, HeaderName}, }; + use pretty_assertions::assert_eq; use reqwest::Url; use serde_json::{Value, json}; use tokio::{net::TcpListener, sync::Mutex, task::JoinHandle}; @@ -907,11 +909,11 @@ mod tests { }; assert_eq!(response.model, "gpt-test"); - assert!(matches!( + assert_matches!( response.choices[0].message.content.as_ref(), Some(crate::gateway::types::openai::MessageContent::Text(text)) if text == "hello from hub" - )); + ); assert_eq!(usage.input_tokens, Some(7)); assert_eq!(usage.output_tokens, Some(9)); assert_eq!(usage.total_tokens, Some(16)); @@ -1064,11 +1066,11 @@ mod tests { ); assert!(response.id.starts_with("bedrock-")); assert_eq!(response.choices[0].finish_reason.as_deref(), Some("stop")); - assert!(matches!( + assert_matches!( response.choices[0].message.content.as_ref(), Some(crate::gateway::types::openai::MessageContent::Text(text)) if text == "hello from bedrock" - )); + ); assert_eq!(usage.total_tokens, Some(16)); let observed = observed.lock().await.take().unwrap(); @@ -1165,11 +1167,11 @@ mod tests { let error = gateway.embed(&request, &instance).await.unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::EmbeddingsNotSupported { provider } if provider == "native-test" - )); + ); } #[tokio::test] @@ -1296,10 +1298,10 @@ mod tests { assert_eq!(response.usage.input_tokens, 5); assert_eq!(response.usage.output_tokens, 9); assert_eq!(response.usage.cache_read_input_tokens, 2); - assert!(matches!( + assert_matches!( &response.content[0], AnthropicContentBlock::Text { text, .. } if text == "hello from hub" - )); + ); assert_eq!(usage.input_tokens, Some(7)); assert_eq!(usage.output_tokens, Some(9)); assert_eq!(usage.total_tokens, Some(16)); @@ -1395,39 +1397,39 @@ mod tests { let message_stop = stream.next().await.unwrap().unwrap(); assert!(stream.next().await.is_none()); - assert!(matches!( + assert_matches!( message_start, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } if message.id == "chatcmpl-789" - )); - assert!(matches!( + ); + assert_matches!( block_start, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( block_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if index == 0 && matches!(&delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") - )); - assert!(matches!( + ); + assert_matches!( block_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( message_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { usage, .. } if usage.input_tokens == Some(5) && usage.output_tokens == Some(9) && usage.cache_creation_input_tokens == Some(0) && usage.cache_read_input_tokens == Some(2) - )); - assert!(matches!( + ); + assert_matches!( message_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop - )); + ); let usage = usage_rx.await.unwrap(); assert_eq!(usage.input_tokens, Some(7)); @@ -1495,10 +1497,10 @@ mod tests { assert_eq!(response.id, "msg_123"); assert_eq!(response.model, "claude-3-5-sonnet-20241022"); - assert!(matches!( + assert_matches!( &response.content[0], AnthropicContentBlock::Text { text, .. } if text == "hello from native" - )); + ); assert_eq!(usage.input_tokens, Some(10)); assert_eq!(usage.output_tokens, Some(4)); assert_eq!(usage.total_tokens, Some(14)); @@ -1573,42 +1575,42 @@ mod tests { let message_stop = stream.next().await.unwrap().unwrap(); assert!(stream.next().await.is_none()); - assert!(matches!( + assert_matches!( message_start, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message } if message.usage.input_tokens == Some(3) && message.usage.output_tokens == Some(1) && message.usage.cache_creation_input_tokens == Some(5) && message.usage.cache_read_input_tokens == Some(2) - )); - assert!(matches!( + ); + assert_matches!( block_start, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( block_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta } if index == 0 && matches!(&delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello") - )); - assert!(matches!( + ); + assert_matches!( block_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index } if index == 0 - )); - assert!(matches!( + ); + assert_matches!( message_delta, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { usage, .. } if usage.input_tokens == Some(3) && usage.output_tokens == Some(4) && usage.cache_creation_input_tokens.is_none() && usage.cache_read_input_tokens.is_none() - )); - assert!(matches!( + ); + assert_matches!( message_stop, crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop - )); + ); let usage = usage_rx.await.unwrap(); assert_eq!(usage.input_tokens, Some(10)); @@ -1938,11 +1940,11 @@ mod tests { .unwrap(); let result = gateway.chat_completion(&request, &instance).await; - assert!(matches!( - result, + assert_matches!( + result.map(|_| ()), Err(GatewayError::Validation(message)) if message.contains("JsonArrayStream") - )); + ); assert_eq!(request_count.load(Ordering::SeqCst), 0); server.abort(); @@ -1983,11 +1985,11 @@ mod tests { let result = gateway .chat::(&request, &instance) .await; - assert!(matches!( - result, + assert_matches!( + result.map(|_| ()), Err(GatewayError::Validation(message)) if message.contains("JsonArrayStream") - )); + ); assert_eq!(request_count.load(Ordering::SeqCst), 0); server.abort(); diff --git a/src/gateway/provider_instance.rs b/src/gateway/provider_instance.rs index c9ac050..bda5d5c 100644 --- a/src/gateway/provider_instance.rs +++ b/src/gateway/provider_instance.rs @@ -159,6 +159,8 @@ impl ProviderRegistryBuilder { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use std::{borrow::Cow, sync::Arc}; use http::{ @@ -283,23 +285,23 @@ mod tests { ); let error = ProviderAuth::None.api_key().unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("ProviderAuth::ApiKey") - )); + ); } #[test] fn provider_auth_api_key_for_adds_provider_context() { let error = ProviderAuth::None.api_key_for("deepseek").unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("deepseek") && message.contains("ProviderAuth::ApiKey") - )); + ); } #[test] @@ -317,12 +319,12 @@ mod tests { let error = ProviderAuth::None .aws_static_credentials_for("bedrock") .unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("bedrock") && message.contains("ProviderAuth::AwsStatic") - )); + ); } #[test] @@ -351,12 +353,12 @@ mod tests { let error = instance.effective_base_url().unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("invalid-url") && message.contains("default_base_url") - )); + ); } #[test] @@ -400,11 +402,11 @@ mod tests { .err() .unwrap(); - assert!(matches!( + assert_matches!( error, GatewayError::Validation(message) if message.contains("dummy") && message.contains("already registered") - )); + ); } } diff --git a/src/gateway/providers/anthropic/mod.rs b/src/gateway/providers/anthropic/mod.rs index 9b56932..b5e78da 100644 --- a/src/gateway/providers/anthropic/mod.rs +++ b/src/gateway/providers/anthropic/mod.rs @@ -120,6 +120,8 @@ impl NativeAnthropicMessagesSupport for AnthropicDef { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::AnthropicDef; @@ -181,7 +183,7 @@ mod tests { assert_eq!(body["model"], "claude-3-5-sonnet-20241022"); assert_eq!(parsed.id, "msg_123"); - assert!(matches!(events.as_slice(), [AnthropicStreamEvent::Ping])); + assert_matches!(events.as_slice(), [AnthropicStreamEvent::Ping]); } #[test] diff --git a/src/gateway/providers/anthropic/transform.rs b/src/gateway/providers/anthropic/transform.rs index 171995b..d5c3a1b 100644 --- a/src/gateway/providers/anthropic/transform.rs +++ b/src/gateway/providers/anthropic/transform.rs @@ -743,6 +743,8 @@ fn now_unix_secs() -> u64 { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::{ @@ -789,10 +791,10 @@ mod tests { assert_eq!(anthropic.messages.len(), 1); assert_eq!(anthropic.messages[0].role, "user"); assert_eq!(anthropic.tools.as_ref().unwrap()[0].name, "get_weather"); - assert!(matches!( + assert_matches!( anthropic.tool_choice, Some(crate::gateway::types::anthropic::AnthropicToolChoice::Auto) - )); + ); assert_eq!( anthropic.metadata.as_ref().unwrap().user_id.as_deref(), Some("user-123") @@ -829,10 +831,10 @@ mod tests { openai.choices[0].finish_reason.as_deref(), Some("tool_calls") ); - assert!(matches!( + assert_matches!( openai.choices[0].message.content.as_ref(), Some(crate::gateway::types::openai::MessageContent::Text(text)) if text == "Let me check." - )); + ); let tool_call = &openai.choices[0].message.tool_calls.as_ref().unwrap()[0]; assert_eq!(tool_call.id, "tu_1"); assert_eq!(tool_call.function.name, "get_weather"); @@ -925,7 +927,7 @@ mod tests { ); let events = parse_anthropic_native_sse(r#"data: {"type":"ping"}"#).unwrap(); - assert!(matches!(events.as_slice(), [AnthropicStreamEvent::Ping])); + assert_matches!(events.as_slice(), [AnthropicStreamEvent::Ping]); } #[test] @@ -938,12 +940,12 @@ mod tests { .unwrap(); let error = openai_to_anthropic_request(&request).unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Bridge(message) if message.contains("n=1") && message.contains('2') - )); + ); } #[test] @@ -956,10 +958,10 @@ mod tests { .unwrap(); let error = openai_to_anthropic_request(&request).unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Bridge(message) if message.contains("requires tools") - )); + ); } } diff --git a/src/gateway/providers/bedrock/transform.rs b/src/gateway/providers/bedrock/transform.rs index 253b2f1..682149e 100644 --- a/src/gateway/providers/bedrock/transform.rs +++ b/src/gateway/providers/bedrock/transform.rs @@ -601,6 +601,8 @@ fn current_unix_timestamp() -> Result { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::{ @@ -698,10 +700,10 @@ mod tests { mapped.choices[0].finish_reason.as_deref(), Some("tool_calls") ); - assert!(matches!( + assert_matches!( mapped.choices[0].message.content.as_ref(), Some(MessageContent::Text(text)) if text == "Need a tool." - )); + ); assert_eq!( mapped.choices[0].message.tool_calls.as_ref().unwrap()[0] .function diff --git a/src/gateway/providers/groq.rs b/src/gateway/providers/groq.rs index a3b3ca2..4e35b39 100644 --- a/src/gateway/providers/groq.rs +++ b/src/gateway/providers/groq.rs @@ -78,6 +78,7 @@ impl ProviderCapabilities for Groq {} #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use serde_json::json; use super::Groq; @@ -90,10 +91,10 @@ mod tests { fn provider_metadata_and_url_are_correct() { let provider = Groq; - pretty_assertions::assert_eq!(provider.name(), "groq"); - pretty_assertions::assert_eq!(provider.default_base_url(), "https://api.groq.com/openai"); + assert_eq!(provider.name(), "groq"); + assert_eq!(provider.default_base_url(), "https://api.groq.com/openai"); - pretty_assertions::assert_eq!( + assert_eq!( provider.build_url(provider.default_base_url(), "ignored"), "https://api.groq.com/openai/v1/chat/completions" ); diff --git a/src/gateway/providers/macros.rs b/src/gateway/providers/macros.rs index 009280a..e6cf80c 100644 --- a/src/gateway/providers/macros.rs +++ b/src/gateway/providers/macros.rs @@ -147,6 +147,8 @@ pub(crate) use provider; #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use std::borrow::Cow; use crate::gateway::{ @@ -202,11 +204,11 @@ mod tests { .build_auth_headers(&ProviderAuth::None) .unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Validation(message) if message.contains("macro-test") && message.contains("ProviderAuth::ApiKey") - )); + ); } } diff --git a/src/gateway/providers/mistral.rs b/src/gateway/providers/mistral.rs index 3f472d6..a7d3045 100644 --- a/src/gateway/providers/mistral.rs +++ b/src/gateway/providers/mistral.rs @@ -25,6 +25,8 @@ provider!(Mistral { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; + use super::Mistral; use crate::gateway::traits::{ChatTransform, ProviderMeta}; @@ -32,14 +34,14 @@ mod tests { fn provider_macro_expands_correctly() { let provider = Mistral; - pretty_assertions::assert_eq!(provider.name(), "mistral"); - pretty_assertions::assert_eq!(provider.default_base_url(), "https://api.mistral.ai"); + assert_eq!(provider.name(), "mistral"); + assert_eq!(provider.default_base_url(), "https://api.mistral.ai"); - pretty_assertions::assert_eq!( + assert_eq!( provider.build_url(provider.default_base_url(), "ignored"), "https://api.mistral.ai/v1/chat/completions" ); - pretty_assertions::assert_eq!(provider.default_quirks().tool_args_may_be_object, true); + assert!(provider.default_quirks().tool_args_may_be_object); } } diff --git a/src/gateway/providers/mod.rs b/src/gateway/providers/mod.rs index f5a0d96..b093b8e 100644 --- a/src/gateway/providers/mod.rs +++ b/src/gateway/providers/mod.rs @@ -60,6 +60,7 @@ pub fn default_provider_registry() -> Result { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::default_provider_registry; #[test] diff --git a/src/gateway/providers/openai.rs b/src/gateway/providers/openai.rs index f6f2320..ec5b5d5 100644 --- a/src/gateway/providers/openai.rs +++ b/src/gateway/providers/openai.rs @@ -56,6 +56,8 @@ impl ProviderCapabilities for OpenAIDef { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::OpenAIDef; @@ -102,12 +104,12 @@ mod tests { .build_auth_headers(&ProviderAuth::None) .unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Validation(message) if message.contains("openai") && message.contains("ProviderAuth::ApiKey") - )); + ); } #[test] diff --git a/src/gateway/providers/openrouter.rs b/src/gateway/providers/openrouter.rs index a6d9358..cab920b 100644 --- a/src/gateway/providers/openrouter.rs +++ b/src/gateway/providers/openrouter.rs @@ -23,6 +23,7 @@ provider!(OpenRouter { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::OpenRouter; use crate::gateway::traits::ProviderMeta; @@ -30,11 +31,11 @@ mod tests { fn provider_macro_expands_correctly() { let provider = OpenRouter; - pretty_assertions::assert_eq!(provider.name(), "openrouter"); - pretty_assertions::assert_eq!(provider.default_base_url(), "https://openrouter.ai/api/v1"); - pretty_assertions::assert_eq!(provider.chat_endpoint_path("ignored"), "/chat/completions"); + assert_eq!(provider.name(), "openrouter"); + assert_eq!(provider.default_base_url(), "https://openrouter.ai/api/v1"); + assert_eq!(provider.chat_endpoint_path("ignored"), "/chat/completions"); - pretty_assertions::assert_eq!( + assert_eq!( provider.build_url(provider.default_base_url(), "ignored"), "https://openrouter.ai/api/v1/chat/completions" ); diff --git a/src/gateway/streams/bridged.rs b/src/gateway/streams/bridged.rs index 1d45a2d..b0af9b7 100644 --- a/src/gateway/streams/bridged.rs +++ b/src/gateway/streams/bridged.rs @@ -138,6 +138,7 @@ impl PinnedDrop for BridgedStream { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; diff --git a/src/gateway/streams/hub.rs b/src/gateway/streams/hub.rs index d500740..23234a8 100644 --- a/src/gateway/streams/hub.rs +++ b/src/gateway/streams/hub.rs @@ -107,6 +107,7 @@ impl Stream for HubChunkStream { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; diff --git a/src/gateway/streams/native.rs b/src/gateway/streams/native.rs index 813494d..dab2de4 100644 --- a/src/gateway/streams/native.rs +++ b/src/gateway/streams/native.rs @@ -119,6 +119,7 @@ impl PinnedDrop for NativeStream { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; diff --git a/src/gateway/streams/reader/aws_event_stream.rs b/src/gateway/streams/reader/aws_event_stream.rs index e23fb01..16df7fe 100644 --- a/src/gateway/streams/reader/aws_event_stream.rs +++ b/src/gateway/streams/reader/aws_event_stream.rs @@ -180,6 +180,8 @@ fn build_aws_event_stream_exception_message(exception_type: &str, payload: &[u8] #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use aws_smithy_eventstream::frame::write_message_to; use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::Bytes; @@ -228,11 +230,11 @@ mod tests { let mut reader = aws_event_stream_reader(byte_stream); - assert!(matches!( + assert_matches!( reader.next().await.unwrap(), Err(GatewayError::Stream(message)) if message.contains("validationException") && message.contains("bad request") - )); + ); assert!(reader.next().await.is_none()); } diff --git a/src/gateway/streams/reader/sse.rs b/src/gateway/streams/reader/sse.rs index ce730b5..3cd83ba 100644 --- a/src/gateway/streams/reader/sse.rs +++ b/src/gateway/streams/reader/sse.rs @@ -71,6 +71,8 @@ where #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use bytes::Bytes; use futures::StreamExt; @@ -116,10 +118,10 @@ mod tests { let mut reader = sse_reader(byte_stream); - assert!(matches!( + assert_matches!( reader.next().await.unwrap(), Err(GatewayError::Http(_)) - )); + ); assert!(reader.next().await.is_none()); } } diff --git a/src/gateway/traits/chat_format.rs b/src/gateway/traits/chat_format.rs index 4e17c03..5b181a6 100644 --- a/src/gateway/traits/chat_format.rs +++ b/src/gateway/traits/chat_format.rs @@ -145,6 +145,8 @@ pub struct ChatStreamState { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use std::borrow::Cow; use http::HeaderMap; @@ -282,10 +284,10 @@ mod tests { let native = NativeHandler::OpenAIResponses(&provider); let error = DummyFormat::call_native(&native, &json!({}), false).unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::NativeNotSupported { provider } if provider == "dummy-native-provider" - )); + ); } #[test] @@ -294,12 +296,12 @@ mod tests { let native = NativeHandler::OpenAIResponses(&provider); let error = DummyFormat::parse_native_response(&native, json!({})).unwrap_err(); - assert!(matches!( + assert_matches!( error, GatewayError::Bridge(message) if message.contains("parse_native_response called on a non-native format") && message.contains("dummy-native-provider") - )); + ); } #[test] diff --git a/src/gateway/traits/provider.rs b/src/gateway/traits/provider.rs index 96e56a3..bbcc738 100644 --- a/src/gateway/traits/provider.rs +++ b/src/gateway/traits/provider.rs @@ -319,6 +319,7 @@ pub trait ImageGenTransform: Send + Sync + 'static {} #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::borrow::Cow; use http::HeaderMap; @@ -483,8 +484,8 @@ mod tests { let body = provider.transform_embeddings_request(&request).unwrap(); match body { EmbedRequestBody::Json(value) => { - pretty_assertions::assert_eq!(value["model"], "text-embedding-3-large"); - pretty_assertions::assert_eq!(value["input"], "hello"); + assert_eq!(value["model"], "text-embedding-3-large"); + assert_eq!(value["input"], "hello"); } } @@ -501,14 +502,14 @@ mod tests { }))) .unwrap(); - pretty_assertions::assert_eq!(response.data.len(), 1); + assert_eq!(response.data.len(), 1); let usage = match response.usage { Some(usage) => usage, None => panic!("expected usage in embedding response"), }; - pretty_assertions::assert_eq!(usage.total_tokens, 2); + assert_eq!(usage.total_tokens, 2); } #[test] diff --git a/src/gateway/types/anthropic.rs b/src/gateway/types/anthropic.rs index f358d42..4b827a3 100644 --- a/src/gateway/types/anthropic.rs +++ b/src/gateway/types/anthropic.rs @@ -325,6 +325,8 @@ pub struct AnthropicError { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::*; @@ -357,7 +359,7 @@ mod tests { "messages": [{"role": "user", "content": "Hi"}] }); let req: AnthropicMessagesRequest = serde_json::from_value(json).unwrap(); - assert!(matches!(req.system, Some(SystemPrompt::Text(ref s)) if s == "You are helpful.")); + assert_matches!(req.system, Some(SystemPrompt::Text(ref s)) if s == "You are helpful."); } #[test] @@ -411,7 +413,7 @@ mod tests { }); let req: AnthropicMessagesRequest = serde_json::from_value(json).unwrap(); assert_eq!(req.tools.as_ref().unwrap().len(), 1); - assert!(matches!(req.tool_choice, Some(AnthropicToolChoice::Auto))); + assert_matches!(req.tool_choice, Some(AnthropicToolChoice::Auto)); } #[test] diff --git a/src/gateway/types/common.rs b/src/gateway/types/common.rs index 6ba6a3b..7bb5865 100644 --- a/src/gateway/types/common.rs +++ b/src/gateway/types/common.rs @@ -158,6 +158,7 @@ pub struct OpenAIResponsesExtras { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; #[test] diff --git a/src/gateway/types/embed.rs b/src/gateway/types/embed.rs index d0fe611..e3aae19 100644 --- a/src/gateway/types/embed.rs +++ b/src/gateway/types/embed.rs @@ -60,6 +60,8 @@ pub struct EmbeddingResponse { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::{EmbeddingRequest, EmbeddingResponse, OneOrMany}; @@ -76,7 +78,7 @@ mod tests { .unwrap(); assert_eq!(request.model, "text-embedding-3-large"); - assert!(matches!(request.input, OneOrMany::Many(_))); + assert_matches!(request.input, OneOrMany::Many(_)); let value = serde_json::to_value(&request).unwrap(); assert_eq!(value["dimensions"], 256); @@ -91,7 +93,7 @@ mod tests { })) .unwrap(); - assert!(matches!(request.input, OneOrMany::One(_))); + assert_matches!(request.input, OneOrMany::One(_)); let value = serde_json::to_value(&request).unwrap(); assert_eq!(value["input"], "hello"); diff --git a/src/gateway/types/openai/mod.rs b/src/gateway/types/openai/mod.rs index 1413500..0fd5e6e 100644 --- a/src/gateway/types/openai/mod.rs +++ b/src/gateway/types/openai/mod.rs @@ -321,6 +321,8 @@ pub struct ChunkFunctionCall { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::*; @@ -359,7 +361,7 @@ mod tests { let req: ChatCompletionRequest = serde_json::from_value(json).unwrap(); assert!(req.tools.is_some()); assert_eq!(req.tools.as_ref().unwrap().len(), 1); - assert!(matches!(req.tool_choice, Some(ToolChoice::Mode(ref s)) if s == "auto")); + assert_matches!(req.tool_choice, Some(ToolChoice::Mode(ref s)) if s == "auto"); } #[test] @@ -437,7 +439,7 @@ mod tests { fn message_content_string_or_parts() { let json = json!({"role": "user", "content": "Hello"}); let msg: ChatMessage = serde_json::from_value(json).unwrap(); - assert!(matches!(msg.content, Some(MessageContent::Text(ref s)) if s == "Hello")); + assert_matches!(msg.content, Some(MessageContent::Text(ref s)) if s == "Hello"); let json = json!({ "role": "user", @@ -447,15 +449,15 @@ mod tests { ] }); let msg: ChatMessage = serde_json::from_value(json).unwrap(); - assert!(matches!(msg.content, Some(MessageContent::Parts(ref p)) if p.len() == 2)); + assert_matches!(msg.content, Some(MessageContent::Parts(ref p)) if p.len() == 2); } #[test] fn stop_condition_single_or_multiple() { let single: StopCondition = serde_json::from_value(json!("stop")).unwrap(); - assert!(matches!(single, StopCondition::Single(ref s) if s == "stop")); + assert_matches!(single, StopCondition::Single(ref s) if s == "stop"); let multiple: StopCondition = serde_json::from_value(json!(["stop", "end"])).unwrap(); - assert!(matches!(multiple, StopCondition::Multiple(ref v) if v.len() == 2)); + assert_matches!(multiple, StopCondition::Multiple(ref v) if v.len() == 2); } } diff --git a/src/gateway/types/openai/responses.rs b/src/gateway/types/openai/responses.rs index f6a3093..5194ebf 100644 --- a/src/gateway/types/openai/responses.rs +++ b/src/gateway/types/openai/responses.rs @@ -406,6 +406,8 @@ impl ResponsesApiStreamEvent { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use serde_json::json; use super::*; @@ -418,7 +420,7 @@ mod tests { }); let req: ResponsesApiRequest = serde_json::from_value(json).unwrap(); assert_eq!(req.model, "gpt-4.1"); - assert!(matches!(req.input, ResponsesInput::Text(ref s) if s == "Hello")); + assert_matches!(req.input, ResponsesInput::Text(ref s) if s == "Hello"); } #[test] @@ -477,11 +479,11 @@ mod tests { assert!( matches!(req.conversation, Some(ConversationReference::Descriptor { ref id }) if id == "conv_123") ); - assert!(matches!( + assert_matches!( req.prompt_cache_retention, Some(PromptCacheRetention::InMemory) - )); - assert!(matches!(req.truncation, Some(Truncation::Auto))); + ); + assert_matches!(req.truncation, Some(Truncation::Auto)); assert_eq!(req.top_logprobs, Some(5)); let serialized = serde_json::to_value(&req).unwrap(); @@ -527,8 +529,8 @@ mod tests { let req: ResponsesApiRequest = serde_json::from_value(json).unwrap(); let tools = req.tools.unwrap(); assert_eq!(tools.len(), 3); - assert!(matches!(&tools[0], ResponsesTool::Function { name, .. } if name == "get_weather")); - assert!(matches!(&tools[1], ResponsesTool::WebSearch { .. })); + assert_matches!(&tools[0], ResponsesTool::Function { name, .. } if name == "get_weather"); + assert_matches!(&tools[1], ResponsesTool::WebSearch { .. }); assert!( matches!(&tools[2], ResponsesTool::FileSearch { vector_store_ids, .. } if vector_store_ids == &["vs_1"]) ); diff --git a/src/proxy/handlers/chat_completions/span_attributes/tests.rs b/src/proxy/handlers/chat_completions/span_attributes/tests.rs index f3270c5..b6268fc 100644 --- a/src/proxy/handlers/chat_completions/span_attributes/tests.rs +++ b/src/proxy/handlers/chat_completions/span_attributes/tests.rs @@ -1,3 +1,4 @@ +use pretty_assertions::assert_eq; use serde_json::{Value, json}; use super::{StreamOutputCollector, request_span_properties, response_span_properties}; diff --git a/src/proxy/handlers/messages/types.rs b/src/proxy/handlers/messages/types.rs index e02b155..2602220 100644 --- a/src/proxy/handlers/messages/types.rs +++ b/src/proxy/handlers/messages/types.rs @@ -181,6 +181,7 @@ fn gateway_error_type(error: &GatewayError) -> &'static str { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use http::StatusCode; use serde_json::json; diff --git a/src/proxy/hooks/rate_limit/concurrent/local.rs b/src/proxy/hooks/rate_limit/concurrent/local.rs index 69c2f46..94f0225 100644 --- a/src/proxy/hooks/rate_limit/concurrent/local.rs +++ b/src/proxy/hooks/rate_limit/concurrent/local.rs @@ -111,6 +111,9 @@ impl ConcurrencyLimiter for LocalConcurrencyLimiter { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; + use super::*; #[tokio::test] @@ -145,13 +148,13 @@ mod tests { let _p2 = limiter.try_acquire("k3", 2).await.unwrap(); let err = limiter.try_acquire("k3", 2).await; - assert!(matches!( - err, + assert_matches!( + err.map(|_| ()), Err(ConcurrencyError::Exceeded { limit: 2, current: 2, }) - )); + ); } #[tokio::test] @@ -186,7 +189,7 @@ mod tests { // "Hot reload" — decrease limit to 1; existing permits stay, new ones rejected let err = limiter.try_acquire("hrd", 1).await; - assert!(matches!(err, Err(ConcurrencyError::Exceeded { .. }))); + assert_matches!(err.map(|_| ()), Err(ConcurrencyError::Exceeded { .. })); // After existing permits drop, should work again with new limit drop(_p1); diff --git a/src/proxy/hooks/rate_limit/concurrent/mod.rs b/src/proxy/hooks/rate_limit/concurrent/mod.rs index 98d9ec1..43fe736 100644 --- a/src/proxy/hooks/rate_limit/concurrent/mod.rs +++ b/src/proxy/hooks/rate_limit/concurrent/mod.rs @@ -82,6 +82,7 @@ pub fn get_concurrency_limiter() -> Arc { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; #[test] diff --git a/src/proxy/hooks/rate_limit/concurrent/utils.rs b/src/proxy/hooks/rate_limit/concurrent/utils.rs index 620eed0..08dbebd 100644 --- a/src/proxy/hooks/rate_limit/concurrent/utils.rs +++ b/src/proxy/hooks/rate_limit/concurrent/utils.rs @@ -114,6 +114,7 @@ impl IntoResponse for ConcurrencyLimitResponse { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; #[test] diff --git a/src/proxy/hooks/rate_limit/ratelimit/local.rs b/src/proxy/hooks/rate_limit/ratelimit/local.rs index 12a11f0..9a4f3c3 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/local.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/local.rs @@ -63,6 +63,8 @@ impl RateLimiter for LocalRateLimiter { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use http::HeaderMap; use super::*; @@ -312,8 +314,8 @@ mod tests { assert!(run_check(&e, CheckPhase::Pre).await.is_ok()); } let (m, err) = run_check(&e, CheckPhase::Pre).await.unwrap_err(); - assert!(matches!(m, RateLimitMetric::RPM)); - assert!(matches!(err, RateLimitError::Exceeded(_))); + assert_matches!(m, RateLimitMetric::RPM); + assert_matches!(err, RateLimitError::Exceeded(_)); } #[tokio::test] @@ -321,8 +323,8 @@ mod tests { let e = make_entity("exc_2", None, Some(100)); assert!(run_check(&e, CheckPhase::Post(90)).await.is_ok()); let (m, err) = run_check(&e, CheckPhase::Post(20)).await.unwrap_err(); - assert!(matches!(m, RateLimitMetric::TPM)); - assert!(matches!(err, RateLimitError::Exceeded(_))); + assert_matches!(m, RateLimitMetric::TPM); + assert_matches!(err, RateLimitError::Exceeded(_)); } #[tokio::test] diff --git a/src/proxy/hooks/rate_limit/ratelimit/mod.rs b/src/proxy/hooks/rate_limit/ratelimit/mod.rs index 559e5a8..df8d7f5 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/mod.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/mod.rs @@ -99,6 +99,7 @@ pub fn get_rate_limiter() -> Arc { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use super::*; use crate::config::entities::types::RateLimit; diff --git a/src/proxy/hooks/rate_limit/ratelimit/utils.rs b/src/proxy/hooks/rate_limit/ratelimit/utils.rs index 14113d3..5446a10 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/utils.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/utils.rs @@ -255,6 +255,7 @@ impl IntoResponse for RateLimitResponse { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::time::Duration; use super::*; diff --git a/src/proxy/provider.rs b/src/proxy/provider.rs index 7a10b41..56692da 100644 --- a/src/proxy/provider.rs +++ b/src/proxy/provider.rs @@ -145,6 +145,8 @@ fn default_bedrock_base_url(region: &str) -> String { #[cfg(test)] mod tests { + use assert_matches::assert_matches; + use pretty_assertions::assert_eq; use reqwest::Url; use super::provider_auth_and_base_url; @@ -298,10 +300,10 @@ mod tests { let error = provider_auth_and_base_url(&config).unwrap_err(); - assert!(matches!( + assert_matches!( error, crate::gateway::error::GatewayError::Internal(message) if message.contains("unsupported scheme") - )); + ); } } diff --git a/src/utils/instance.rs b/src/utils/instance.rs index b972ceb..afddd77 100644 --- a/src/utils/instance.rs +++ b/src/utils/instance.rs @@ -97,6 +97,7 @@ fn write_id_file(path: &Path, id: &str) -> Result<()> { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use std::path::PathBuf; use super::*;