diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java
index 730f53af2..cd5b4d7bf 100644
--- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java
+++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java
@@ -19,7 +19,14 @@
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.genai.types.FunctionCall;
+import com.google.genai.types.Part;
+import java.util.Base64;
 import java.util.Map;
+import org.jspecify.annotations.Nullable;
 
 /** Shared models for Chat Completions Request and Response. */
 @JsonIgnoreProperties(ignoreUnknown = true)
@@ -28,6 +35,17 @@ final class ChatCompletionsCommon {
 
   private ChatCompletionsCommon() {}
 
+  private static final ObjectMapper objectMapper = new ObjectMapper();
+
+  public static final String ROLE_ASSISTANT = "assistant";
+  public static final String ROLE_MODEL = "model";
+
+  public static final String METADATA_KEY_ID = "id";
+  public static final String METADATA_KEY_CREATED = "created";
+  public static final String METADATA_KEY_OBJECT = "object";
+  public static final String METADATA_KEY_SYSTEM_FINGERPRINT = "system_fingerprint";
+  public static final String METADATA_KEY_SERVICE_TIER = "service_tier";
+
   /**
    * See
    * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_tool_call%20%3E%20(schema)
@@ -56,6 +74,46 @@ static class ToolCall {
      */
     @JsonProperty("extra_content")
     public Map<String, Object> extraContent;
+
+    /**
+     * Converts the tool call to a {@link Part}.
+     *
+     * @return a {@link Part} containing the function call, or {@code null} if this tool call does
+     *     not contain a function call.
+     * @throws IllegalArgumentException if the function arguments cannot be parsed or the thought
+     *     signature is not valid Base64.
+     */
+    public @Nullable Part toPart() {
+      if (function != null) {
+        FunctionCall fc = function.toFunctionCall(id);
+        Part part = Part.builder().functionCall(fc).build();
+        return applyThoughtSignature(part);
+      }
+      return null;
+    }
+
+    /**
+     * Applies the thought signature from {@code extraContent} to the given {@link Part} if present.
+     * This is used to support the Google Gemini/Vertex AI implementation of the chat/completions
+     * API.
+     *
+     * @param part the {@link Part} to modify.
+     * @return a new {@link Part} with the thought signature applied, or the original {@link Part}
+     *     if no thought signature is found.
+     * @throws IllegalArgumentException if the thought signature is not valid Base64.
+     */
+    public Part applyThoughtSignature(Part part) {
+      if (extraContent != null && extraContent.containsKey("google")) {
+        Object googleObj = extraContent.get("google");
+        if (googleObj instanceof Map<?, ?> googleMap) {
+          Object sigObj = googleMap.get("thought_signature");
+          if (sigObj instanceof String sig) {
+            return part.toBuilder().thoughtSignature(Base64.getDecoder().decode(sig)).build();
+          }
+        }
+      }
+      return part;
+    }
   }
 
   /**
@@ -70,6 +128,46 @@ static class Function {
 
     /** See class definition for more details. */
     public String arguments; // JSON string
+
+    /**
+     * Converts this function to a {@link FunctionCall}.
+     *
+     * <p>Blank {@code arguments} and the JSON literal {@code null} are treated as "no arguments":
+     * the resulting {@link FunctionCall} is built without args.
+     *
+     * @param toolCallId the ID of the tool call, or {@code null} if not applicable.
+     * @return the {@link FunctionCall} object.
+     * @throws IllegalArgumentException if {@code arguments} cannot be parsed as a JSON object.
+     */
+    public FunctionCall toFunctionCall(@Nullable String toolCallId) {
+      FunctionCall.Builder fcBuilder = FunctionCall.builder();
+      if (name != null) {
+        fcBuilder.name(name);
+      }
+      Map<String, Object> args = parseArguments();
+      if (args != null) {
+        fcBuilder.args(args);
+      }
+      if (toolCallId != null) {
+        fcBuilder.id(toolCallId);
+      }
+      return fcBuilder.build();
+    }
+
+    /** Parses {@code arguments}; returns {@code null} when there are no arguments to set. */
+    private @Nullable Map<String, Object> parseArguments() {
+      // Jackson fails on empty input, so blank arguments are handled before parsing.
+      if (arguments == null || arguments.isBlank()) {
+        return null;
+      }
+      try {
+        // readValue returns null for the JSON literal "null"; the caller skips args then.
+        return objectMapper.readValue(arguments, new TypeReference<Map<String, Object>>() {});
+      } catch (JsonProcessingException e) {
+        throw new IllegalArgumentException(
+            "Failed to parse function arguments JSON: " + arguments, e);
+      }
+    }
   }
 
   /**
diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java
index 75a96e6ee..c52389aa3 100644
--- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java
+++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java
@@ -19,7 +19,16 @@
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.adk.models.LlmResponse;
+import com.google.genai.types.Content;
+import com.google.genai.types.CustomMetadata;
+import com.google.genai.types.FinishReason;
+import com.google.genai.types.FinishReason.Known;
+import com.google.genai.types.GenerateContentResponseUsageMetadata;
+import com.google.genai.types.Part;
+import java.util.ArrayList;
 import java.util.List;
+import org.jspecify.annotations.Nullable;
 
 /**
  * Data Transfer Objects for Chat Completion and Chat Completion Chunk API responses.
@@ -62,6 +71,173 @@ static class ChatCompletion {
 
     /** See class definition for more details. */
     public Usage usage;
+
+    /**
+     * Converts this chat completion to a {@link LlmResponse}.
+     *
+     * <p>Only the first choice is converted; any further choices are ignored.
+     *
+     * @return the {@link LlmResponse} object.
+     */
+    public LlmResponse toLlmResponse() {
+      Choice choice = (choices != null && !choices.isEmpty()) ? choices.get(0) : null;
+      Content content = mapChoiceToContent(choice);
+
+      LlmResponse.Builder builder = LlmResponse.builder().content(content);
+
+      if (choice != null) {
+        builder.finishReason(mapFinishReason(choice.finishReason));
+      }
+
+      if (model != null) {
+        builder.modelVersion(model);
+      }
+
+      if (usage != null) {
+        builder.usageMetadata(mapUsage(usage));
+      }
+
+      List<CustomMetadata> customMetadataList = buildCustomMetadata();
+      return builder.customMetadata(customMetadataList).build();
+    }
+
+    /**
+     * Maps the finish reason string to a {@link FinishReason}.
+     *
+     * @param reason the finish reason string.
+     * @return the {@link FinishReason}, or {@code null} if the input reason is null.
+     */
+    private @Nullable FinishReason mapFinishReason(@Nullable String reason) {
+      if (reason == null) {
+        return null;
+      }
+      return switch (reason) {
+        case "stop", "tool_calls" -> new FinishReason(Known.STOP.toString());
+        case "length" -> new FinishReason(Known.MAX_TOKENS.toString());
+        case "content_filter" -> new FinishReason(Known.SAFETY.toString());
+        default -> new FinishReason(Known.OTHER.toString());
+      };
+    }
+
+    /** Maps token usage; {@code thoughtsTokenCount} takes precedence over reasoning tokens. */
+    private GenerateContentResponseUsageMetadata mapUsage(Usage usage) {
+      GenerateContentResponseUsageMetadata.Builder builder =
+          GenerateContentResponseUsageMetadata.builder();
+      if (usage.promptTokens != null) {
+        builder.promptTokenCount(usage.promptTokens);
+      }
+      if (usage.completionTokens != null) {
+        builder.candidatesTokenCount(usage.completionTokens);
+      }
+      if (usage.totalTokens != null) {
+        builder.totalTokenCount(usage.totalTokens);
+      }
+      if (usage.thoughtsTokenCount != null) {
+        builder.thoughtsTokenCount(usage.thoughtsTokenCount);
+      } else if (usage.completionTokensDetails != null
+          && usage.completionTokensDetails.reasoningTokens != null) {
+        builder.thoughtsTokenCount(usage.completionTokensDetails.reasoningTokens);
+      }
+      return builder.build();
+    }
+
+    /**
+     * Maps the chosen completion to a {@link Content} object.
+     *
+     * @param choice the completion choice to map, or {@code null}.
+     * @return the {@link Content} object, which will be empty if the choice or its message is null.
+     */
+    private Content mapChoiceToContent(@Nullable Choice choice) {
+      Content.Builder contentBuilder = Content.builder();
+      if (choice != null && choice.message != null) {
+        String role = mapRole(choice.message.role);
+        if (role != null) {
+          // Leave the role unset rather than handing null to the builder.
+          contentBuilder.role(role);
+        }
+        contentBuilder.parts(mapMessageToParts(choice.message));
+      }
+      return contentBuilder.build();
+    }
+
+    /** Maps the "assistant" role to "model"; any other value, including null, passes through. */
+    private @Nullable String mapRole(@Nullable String role) {
+      return ChatCompletionsCommon.ROLE_ASSISTANT.equals(role)
+          ? ChatCompletionsCommon.ROLE_MODEL
+          : role;
+    }
+
+    /** Collects text content, refusal text, and tool calls (in that order) as parts. */
+    private List<Part> mapMessageToParts(Message message) {
+      List<Part> parts = new ArrayList<>();
+      if (message.content != null) {
+        parts.add(Part.fromText(message.content));
+      }
+      if (message.refusal != null) {
+        parts.add(Part.fromText(message.refusal));
+      }
+      if (message.toolCalls != null) {
+        parts.addAll(mapToolCallsToParts(message.toolCalls));
+      }
+      return parts;
+    }
+
+    /** Converts tool calls to parts, skipping tool calls that carry no function. */
+    private List<Part> mapToolCallsToParts(List<ChatCompletionsCommon.ToolCall> toolCalls) {
+      List<Part> parts = new ArrayList<>();
+      for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) {
+        Part part = toolCall.toPart();
+        if (part != null) {
+          parts.add(part);
+        }
+      }
+      return parts;
+    }
+
+    /**
+     * Builds the list of custom metadata from the chat completion fields.
+     *
+     * @return a list of {@link CustomMetadata}, which will be empty if no relevant fields are set.
+     */
+    private List<CustomMetadata> buildCustomMetadata() {
+      List<CustomMetadata> customMetadataList = new ArrayList<>();
+      if (id != null) {
+        customMetadataList.add(
+            CustomMetadata.builder()
+                .key(ChatCompletionsCommon.METADATA_KEY_ID)
+                .stringValue(id)
+                .build());
+      }
+      if (created != null) {
+        customMetadataList.add(
+            CustomMetadata.builder()
+                .key(ChatCompletionsCommon.METADATA_KEY_CREATED)
+                .stringValue(created.toString())
+                .build());
+      }
+      if (object != null) {
+        customMetadataList.add(
+            CustomMetadata.builder()
+                .key(ChatCompletionsCommon.METADATA_KEY_OBJECT)
+                .stringValue(object)
+                .build());
+      }
+      if (systemFingerprint != null) {
+        customMetadataList.add(
+            CustomMetadata.builder()
+                .key(ChatCompletionsCommon.METADATA_KEY_SYSTEM_FINGERPRINT)
+                .stringValue(systemFingerprint)
+                .build());
+      }
+      if (serviceTier != null) {
+        customMetadataList.add(
+            CustomMetadata.builder()
+                .key(ChatCompletionsCommon.METADATA_KEY_SERVICE_TIER)
+                .stringValue(serviceTier)
+                .build());
+      }
+      return customMetadataList;
+    }
   }
 
   /**
diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java
index 52134476c..dd1a5d85a 100644
--- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java
+++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java
@@ -19,8 +19,15 @@
 import static com.google.common.truth.Truth.assertThat;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.adk.models.LlmResponse;
 import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletion;
 import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletionChunk;
+import com.google.genai.types.CustomMetadata;
+import com.google.genai.types.FinishReason.Known;
+import com.google.genai.types.FunctionCall;
+import com.google.genai.types.Part;
+import java.util.Base64;
+import java.util.List;
 import java.util.Map;
 import org.junit.Before;
import org.junit.Test; @@ -135,7 +142,7 @@ public void testDeserializeChatCompletion_withThoughtSignatureAndGeminiTokens() assertThat(completion.choices.get(0).message.toolCalls).hasSize(1); assertThat(completion.choices.get(0).message.toolCalls.get(0).extraContent).isNotNull(); Map extraContentMap = - (Map) completion.choices.get(0).message.toolCalls.get(0).extraContent; + completion.choices.get(0).message.toolCalls.get(0).extraContent; @SuppressWarnings("unchecked") // This code won't run in production and it's is a JSON object. Map googleMap = (Map) extraContentMap.get("google"); assertThat(googleMap.get("thought_signature")).isEqualTo("c2lnbmF0dXJl"); @@ -168,8 +175,7 @@ public void testDeserializeChatCompletion_withArbitraryExtraContent() throws Exc ChatCompletion got = objectMapper.readValue(json, ChatCompletion.class); assertThat(got.choices.get(0).message.toolCalls).hasSize(1); - Map extraContent = - (Map) got.choices.get(0).message.toolCalls.get(0).extraContent; + Map extraContent = got.choices.get(0).message.toolCalls.get(0).extraContent; assertThat(extraContent.get("custom_key")).isEqualTo("custom_value"); @SuppressWarnings("unchecked") // This code won't run in production and it's is a JSON object. 
Map nested = (Map) extraContent.get("nested"); @@ -325,4 +331,360 @@ public void testDeserializeChatCompletionChunk_withToolCallDelta() throws Except assertThat(chunk.usage.promptTokens).isEqualTo(5); assertThat(chunk.usage.totalTokens).isEqualTo(15); } + + @Test + public void testToLlmResponse_simpleText() throws Exception { + String json = + """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1694268190, + "model": "gpt-4", + "system_fingerprint": "fp_123", + "service_tier": "scale", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Hello world" + }, + "finish_reason": "stop" + }], + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15, + "thoughts_token_count": 42 + } + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.modelVersion()).hasValue("gpt-4"); + assertThat(response.finishReason().get().knownEnum()).isEqualTo(Known.STOP); + + // Usage Metadata + assertThat(response.usageMetadata().get().promptTokenCount()).hasValue(5); + assertThat(response.usageMetadata().get().candidatesTokenCount()).hasValue(10); + assertThat(response.usageMetadata().get().totalTokenCount()).hasValue(15); + assertThat(response.usageMetadata().get().thoughtsTokenCount()).hasValue(42); + + // Content + assertThat(response.content().get().role()).hasValue("model"); + assertThat(response.content().get().parts().get().get(0).text()).hasValue("Hello world"); + + // Custom Metadata + List metadata = response.customMetadata().get(); + assertThat(metadata).hasSize(5); + assertThat(metadata.get(0).key()).hasValue("id"); + assertThat(metadata.get(0).stringValue()).hasValue("chatcmpl-123"); + assertThat(metadata.get(1).key()).hasValue("created"); + assertThat(metadata.get(1).stringValue()).hasValue("1694268190"); + 
assertThat(metadata.get(2).key()).hasValue("object"); + assertThat(metadata.get(2).stringValue()).hasValue("chat.completion"); + assertThat(metadata.get(3).key()).hasValue("system_fingerprint"); + assertThat(metadata.get(3).stringValue()).hasValue("fp_123"); + assertThat(metadata.get(4).key()).hasValue("service_tier"); + assertThat(metadata.get(4).stringValue()).hasValue("scale"); + } + + @Test + public void testToLlmResponse_userRole() throws Exception { + String json = + """ + { + "choices": [{ + "index": 0, + "message": { + "role": "user", + "content": "Hello world" + }, + "finish_reason": "stop" + }] + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.content().get().role()).hasValue("user"); + } + + @Test + public void testToLlmResponse_withToolCall_simple() throws Exception { + String json = + """ + { + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\\\"location\\\":\\\"Seattle\\\"}" + } + }] + } + }] + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + Part part = response.content().get().parts().get().get(0); + FunctionCall fc = part.functionCall().get(); + assertThat(fc.id()).hasValue("call_123"); + assertThat(fc.name()).hasValue("get_weather"); + assertThat(fc.args().get().get("location")).isEqualTo("Seattle"); + + assertThat(response.customMetadata().get()).isEmpty(); + } + + @Test + public void testToLlmResponse_thoughtSignature() throws Exception { + String json = + """ + { + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + 
"arguments": "{\\\"location\\\":\\\"Seattle\\\"}" + }, + "extra_content": { + "google": { + "thought_signature": "c2ln" + } + } + }] + } + }] + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.content().get().parts().get().get(0).thoughtSignature().get()) + .isEqualTo(Base64.getDecoder().decode("c2ln")); + } + + @Test + public void testToLlmResponse_withRefusal() throws Exception { + String json = + """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0125", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "refusal": "System error or refusal" + }, + "finish_reason": "stop" + }] + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.modelVersion()).hasValue("gpt-3.5-turbo-0125"); + assertThat(response.finishReason().get().knownEnum()).isEqualTo(Known.STOP); + + // Content + assertThat(response.content().get().role()).hasValue("model"); + assertThat(response.content().get().parts().get().get(0).text()) + .hasValue("System error or refusal"); + + // Custom Metadata + List metadata = response.customMetadata().get(); + assertThat(metadata).hasSize(3); + assertThat(metadata.get(0).key()).hasValue("id"); + assertThat(metadata.get(0).stringValue()).hasValue("chatcmpl-123"); + assertThat(metadata.get(1).key()).hasValue("created"); + assertThat(metadata.get(1).stringValue()).hasValue("1677652288"); + assertThat(metadata.get(2).key()).hasValue("object"); + assertThat(metadata.get(2).stringValue()).hasValue("chat.completion"); + } + + @Test + public void testToLlmResponse_reasoningTokens() throws Exception { + String json = + """ + { + "choices": [{ + "message": { + "role": 
"assistant", + "content": "hello" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + "completion_tokens_details": { + "reasoning_tokens": 4 + } + } + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.finishReason().get().knownEnum()).isEqualTo(Known.STOP); + + // Content + assertThat(response.content().get().role()).hasValue("model"); + assertThat(response.content().get().parts().get().get(0).text()).hasValue("hello"); + + // Usage Metadata + assertThat(response.usageMetadata().get().promptTokenCount()).hasValue(10); + assertThat(response.usageMetadata().get().candidatesTokenCount()).hasValue(5); + assertThat(response.usageMetadata().get().totalTokenCount()).hasValue(15); + assertThat(response.usageMetadata().get().thoughtsTokenCount()).hasValue(4); + + assertThat(response.customMetadata().get()).isEmpty(); + } + + @Test + public void testToolCallToPart_withFunction() throws Exception { + String json = + """ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\\\"location\\\":\\\"Seattle\\\"}" + } + } + """; + ChatCompletionsCommon.ToolCall toolCall = + objectMapper.readValue(json, ChatCompletionsCommon.ToolCall.class); + + Part part = toolCall.toPart(); + + assertThat(part).isNotNull(); + assertThat(part.functionCall()).isPresent(); + FunctionCall fc = part.functionCall().get(); + assertThat(fc.id()).hasValue("call_123"); + assertThat(fc.name()).hasValue("get_weather"); + } + + @Test + public void testToolCallToPart_withFunction_nullId() throws Exception { + String json = + """ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\\\"location\\\":\\\"Seattle\\\"}" + } + } + """; + ChatCompletionsCommon.ToolCall toolCall = + 
objectMapper.readValue(json, ChatCompletionsCommon.ToolCall.class); + + Part part = toolCall.toPart(); + + assertThat(part).isNotNull(); + assertThat(part.functionCall()).isPresent(); + FunctionCall fc = part.functionCall().get(); + assertThat(fc.id()).isEmpty(); + } + + @Test + public void testToolCallToPart_withThoughtSignature() throws Exception { + String json = + """ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\\\"location\\\":\\\"Seattle\\\"}" + }, + "extra_content": { + "google": { + "thought_signature": "c2ln" + } + } + } + """; + ChatCompletionsCommon.ToolCall toolCall = + objectMapper.readValue(json, ChatCompletionsCommon.ToolCall.class); + + Part part = toolCall.toPart(); + + assertThat(part).isNotNull(); + assertThat(part.thoughtSignature().get()).isEqualTo(Base64.getDecoder().decode("c2ln")); + } + + @Test + public void testToolCallToPart_nullFunction() throws Exception { + String json = + """ + { + "id": "call_123", + "type": "function" + } + """; + ChatCompletionsCommon.ToolCall toolCall = + objectMapper.readValue(json, ChatCompletionsCommon.ToolCall.class); + + Part part = toolCall.toPart(); + + assertThat(part).isNull(); + } + + @Test + public void testToLlmResponse_noChoices() throws Exception { + String json = + """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4" + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.modelVersion()).hasValue("gpt-4"); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isEmpty(); + } }