From a109e7e9672d15e86baef3281ede2572e6224158 Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 26 Dec 2024 18:32:53 +0800 Subject: [PATCH] feat: support gemini-2.0-flash-thinking-exp (#1067) --- models.yaml | 15 +++++++++++ src/client/vertexai.rs | 60 ++++++++++++++++++++---------------------- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/models.yaml b/models.yaml index 353251cb..d76fc141 100644 --- a/models.yaml +++ b/models.yaml @@ -114,6 +114,12 @@ output_price: 0 supports_vision: true supports_function_calling: true + - name: gemini-2.0-flash-thinking-exp + max_input_tokens: 32768 + max_output_tokens: 8192 + input_price: 0 + output_price: 0 + supports_vision: true - name: gemini-exp-1206 max_input_tokens: 32768 max_output_tokens: 8192 @@ -452,6 +458,15 @@ output_price: 0.075 supports_vision: true supports_function_calling: true + - name: gemini-2.0-flash-exp + max_input_tokens: 1048576 + max_output_tokens: 8192 + supports_vision: true + supports_function_calling: true + - name: gemini-2.0-flash-thinking-exp-1219 + max_input_tokens: 32768 + max_output_tokens: 8192 + supports_vision: true - name: claude-3-5-sonnet-v2@20241022 max_input_tokens: 200000 max_output_tokens: 8192 diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 7b731648..1612b347 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -197,24 +197,25 @@ pub async fn gemini_chat_completions_streaming( let handle = |value: &str| -> Result<()> { let data: Value = serde_json::from_str(value)?; debug!("stream-data: {data}"); - if let Some(text) = data["candidates"][0]["content"]["parts"][0]["text"].as_str() { - if !text.is_empty() { - handler.text(text)?; - } - } else if let Some("SAFETY") = data["promptFeedback"]["blockReason"] - .as_str() - .or_else(|| data["candidates"][0]["finishReason"].as_str()) - { - bail!("Content Blocked") - } else if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() { - for part in parts { - if let (Some(name), Some(args)) = ( + if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() { + for (i, part) in parts.iter().enumerate() { + if let Some(text) = part["text"].as_str() { + if i > 0 { + handler.text("\n\n")?; + } + handler.text(text)?; + } else if let (Some(name), Some(args)) = ( part["functionCall"]["name"].as_str(), part["functionCall"]["args"].as_object(), ) { handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?; } } + } else if let Some("SAFETY") = data["promptFeedback"]["blockReason"] + .as_str() + .or_else(|| data["candidates"][0]["finishReason"].as_str()) + { + bail!("Blocked due to safety") } Ok(()) @@ -257,38 +258,35 @@ struct EmbeddingsResBodyPredictionEmbeddings { } fn gemini_extract_chat_completions_text(data: &Value) -> Result { - let text = data["candidates"][0]["content"]["parts"][0]["text"] - .as_str() - .unwrap_or_default(); - + let mut text_parts = vec![]; let mut tool_calls = vec![]; if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() { - tool_calls = parts - .iter() - .filter_map(|part| { - if let (Some(name), Some(args)) = ( - part["functionCall"]["name"].as_str(), - part["functionCall"]["args"].as_object(), - ) { - Some(ToolCall::new(name.to_string(), json!(args), None)) - } else { - None - } - }) - .collect() + for part in parts { + if let Some(text) = part["text"].as_str() { + text_parts.push(text); + } + if let (Some(name), Some(args)) = ( + part["functionCall"]["name"].as_str(), + part["functionCall"]["args"].as_object(), + ) { + tool_calls.push(ToolCall::new(name.to_string(), json!(args), None)); + } + } } + + let text = text_parts.join("\n\n"); if text.is_empty() && tool_calls.is_empty() { if let Some("SAFETY") = data["promptFeedback"]["blockReason"] .as_str() .or_else(|| data["candidates"][0]["finishReason"].as_str()) { - bail!("Content Blocked") + bail!("Blocked due to safety") } else { bail!("Invalid response data: {data}"); } } let output = ChatCompletionsOutput { - text: text.to_string(), + text, tool_calls, id: None, input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),