Skip to content

Commit

Permalink
feat: support gemini-2.0-flash-thinking-exp (#1067)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Dec 26, 2024
1 parent 90475d3 commit a109e7e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 31 deletions.
15 changes: 15 additions & 0 deletions models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@
output_price: 0
supports_vision: true
supports_function_calling: true
- name: gemini-2.0-flash-thinking-exp
max_input_tokens: 32768
max_output_tokens: 8192
input_price: 0
output_price: 0
supports_vision: true
- name: gemini-exp-1206
max_input_tokens: 32768
max_output_tokens: 8192
Expand Down Expand Up @@ -452,6 +458,15 @@
output_price: 0.075
supports_vision: true
supports_function_calling: true
- name: gemini-2.0-flash-exp
max_input_tokens: 1048576
max_output_tokens: 8192
supports_vision: true
supports_function_calling: true
- name: gemini-2.0-flash-thinking-exp-1219
max_input_tokens: 32768
max_output_tokens: 8192
supports_vision: true
- name: claude-3-5-sonnet-v2@20241022
max_input_tokens: 200000
max_output_tokens: 8192
Expand Down
60 changes: 29 additions & 31 deletions src/client/vertexai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,24 +197,25 @@ pub async fn gemini_chat_completions_streaming(
let handle = |value: &str| -> Result<()> {
let data: Value = serde_json::from_str(value)?;
debug!("stream-data: {data}");
if let Some(text) = data["candidates"][0]["content"]["parts"][0]["text"].as_str() {
if !text.is_empty() {
handler.text(text)?;
}
} else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Content Blocked")
} else if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
for part in parts {
if let (Some(name), Some(args)) = (
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
for (i, part) in parts.iter().enumerate() {
if let Some(text) = part["text"].as_str() {
if i > 0 {
handler.text("\n\n")?;
}
handler.text(text)?;
} else if let (Some(name), Some(args)) = (
part["functionCall"]["name"].as_str(),
part["functionCall"]["args"].as_object(),
) {
handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?;
}
}
} else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Blocked due to safety")
}

Ok(())
Expand Down Expand Up @@ -257,38 +258,35 @@ struct EmbeddingsResBodyPredictionEmbeddings {
}

fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
let text = data["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.unwrap_or_default();

let mut text_parts = vec![];
let mut tool_calls = vec![];
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
tool_calls = parts
.iter()
.filter_map(|part| {
if let (Some(name), Some(args)) = (
part["functionCall"]["name"].as_str(),
part["functionCall"]["args"].as_object(),
) {
Some(ToolCall::new(name.to_string(), json!(args), None))
} else {
None
}
})
.collect()
for part in parts {
if let Some(text) = part["text"].as_str() {
text_parts.push(text);
}
if let (Some(name), Some(args)) = (
part["functionCall"]["name"].as_str(),
part["functionCall"]["args"].as_object(),
) {
tool_calls.push(ToolCall::new(name.to_string(), json!(args), None));
}
}
}

let text = text_parts.join("\n\n");
if text.is_empty() && tool_calls.is_empty() {
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
.as_str()
.or_else(|| data["candidates"][0]["finishReason"].as_str())
{
bail!("Content Blocked")
bail!("Blocked due to safety")
} else {
bail!("Invalid response data: {data}");
}
}
let output = ChatCompletionsOutput {
text: text.to_string(),
text,
tool_calls,
id: None,
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),
Expand Down

0 comments on commit a109e7e

Please sign in to comment.