Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/config/entities/apikeys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl ApiKeysStore {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;

use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error};
Expand Down
1 change: 1 addition & 0 deletions src/config/entities/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ impl<T: DeserializeOwned + Clone + Send + Sync + 'static> EntityStore<T> {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use std::{sync::Mutex, time::Duration};

use anyhow::Result;
Expand Down
1 change: 1 addition & 0 deletions src/config/entities/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ impl ModelsStore {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;

use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error};
Expand Down
7 changes: 3 additions & 4 deletions src/config/entities/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ impl ProvidersStore {

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use serde_json::json;

use super::{SCHEMA, SCHEMA_VALIDATOR, format_evaluation_error};
Expand Down Expand Up @@ -212,9 +214,6 @@ mod tests {

assert_eq!(provider.name, "openai-primary");
assert_eq!(provider.provider_type(), "openai");
assert!(matches!(
provider.provider,
super::ProviderConfig::OpenAI(_)
));
assert_matches!(provider.provider, super::ProviderConfig::OpenAI(_));
}
}
1 change: 1 addition & 0 deletions src/gateway/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ pub type Result<T> = std::result::Result<T, GatewayError>;

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use serde_json::json;

use super::*;
Expand Down
58 changes: 30 additions & 28 deletions src/gateway/formats/anthropic_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,8 @@ fn openai_finish_reason_to_anthropic_stream(finish_reason: &str) -> String {

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use serde_json::json;

use super::AnthropicMessagesFormat;
Expand Down Expand Up @@ -1148,24 +1150,24 @@ mod tests {
&BridgeContext::default(),
)
.unwrap();
assert!(matches!(
assert_matches!(
&events[0],
crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStart { message }
if message.id == "chatcmpl-123"
&& message.usage.input_tokens.is_none()
&& message.usage.output_tokens.is_none()
));
assert!(matches!(
);
assert_matches!(
&events[1],
crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, .. }
if *index == 0
));
assert!(matches!(
);
assert_matches!(
&events[2],
crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta }
if *index == 0
&& matches!(delta, crate::gateway::types::anthropic::ContentDelta::TextDelta { text } if text == "hello")
));
);

assert!(
AnthropicMessagesFormat::from_hub_stream(
Expand All @@ -1179,24 +1181,24 @@ mod tests {

let end_events =
AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default());
assert!(matches!(
assert_matches!(
&end_events[0],
crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStop { index }
if *index == 0
));
assert!(matches!(
);
assert_matches!(
&end_events[1],
crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, usage }
if delta.stop_reason.is_none()
&& usage.input_tokens == Some(5)
&& usage.output_tokens == Some(9)
&& usage.cache_creation_input_tokens == Some(0)
&& usage.cache_read_input_tokens == Some(2)
));
assert!(matches!(
);
assert_matches!(
&end_events[2],
crate::gateway::types::anthropic::AnthropicStreamEvent::MessageStop
));
);
}

#[test]
Expand Down Expand Up @@ -1242,16 +1244,16 @@ mod tests {
assert_eq!(bridged.usage.output_tokens, 7);
assert_eq!(bridged.usage.cache_read_input_tokens, 2);
assert!(bridged.usage.cache_creation.is_none());
assert!(matches!(
assert_matches!(
&bridged.content[0],
crate::gateway::types::anthropic::AnthropicContentBlock::Text { text, .. }
if text == "Calling tool"
));
assert!(matches!(
);
assert_matches!(
&bridged.content[1],
crate::gateway::types::anthropic::AnthropicContentBlock::ToolUse { name, .. }
if name == "get_weather"
));
);
}

#[test]
Expand Down Expand Up @@ -1308,39 +1310,39 @@ mod tests {
&BridgeContext::default(),
)
.unwrap();
assert!(matches!(
assert_matches!(
&first_events[1],
crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockStart { index, content_block }
if *index == 0
&& matches!(content_block, crate::gateway::types::anthropic::AnthropicContentBlock::ToolUse { name, .. } if name == "get_weather")
));
assert!(matches!(
);
assert_matches!(
&first_events[2],
crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta }
if *index == 0
&& matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == "{\"city\"")
));
);

let second_events = AnthropicMessagesFormat::from_hub_stream(
&second_chunk,
&mut state,
&BridgeContext::default(),
)
.unwrap();
assert!(matches!(
assert_matches!(
&second_events[0],
crate::gateway::types::anthropic::AnthropicStreamEvent::ContentBlockDelta { index, delta }
if *index == 0
&& matches!(delta, crate::gateway::types::anthropic::ContentDelta::InputJsonDelta { partial_json } if partial_json == ":\"SF\"}")
));
);

let end_events =
AnthropicMessagesFormat::stream_end_events(&mut state, &BridgeContext::default());
assert!(matches!(
assert_matches!(
&end_events[1],
crate::gateway::types::anthropic::AnthropicStreamEvent::MessageDelta { delta, .. }
if delta.stop_reason.as_deref() == Some("tool_use")
));
);
}

#[test]
Expand Down Expand Up @@ -1374,11 +1376,11 @@ mod tests {
&mut missing_type_state,
&BridgeContext::default(),
);
assert!(matches!(
assert_matches!(
missing_type_result,
Err(GatewayError::Bridge(message))
if message.contains("requires tool call types")
));
);

let mut invalid_type_state = super::AnthropicBridgeState::default();
let invalid_type_chunk: crate::gateway::types::openai::ChatCompletionChunk =
Expand Down Expand Up @@ -1410,11 +1412,11 @@ mod tests {
&mut invalid_type_state,
&BridgeContext::default(),
);
assert!(matches!(
assert_matches!(
invalid_type_result,
Err(GatewayError::Bridge(message))
if message.contains("only supports function tool calls")
));
);
}

#[test]
Expand Down
6 changes: 4 additions & 2 deletions src/gateway/formats/openai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ impl ChatFormat for OpenAIChatFormat {

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use serde_json::json;

use super::OpenAIChatFormat;
Expand Down Expand Up @@ -224,10 +226,10 @@ mod tests {
let error = OpenAIChatFormat::transform_native_stream_chunk(&provider, "data: {}", &mut ())
.unwrap_err();

assert!(matches!(
assert_matches!(
error,
GatewayError::NativeNotSupported { provider } if provider == "dummy"
));
);
assert!(OpenAIChatFormat::native_support(&provider).is_none());
}
}
Loading