diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 270f3bef43d..6799a3a8551 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; import com.fasterxml.jackson.core.type.TypeReference; import io.micrometer.observation.Observation; @@ -42,6 +43,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.anthropic.api.tool.Tool; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -342,11 +344,11 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage return new ChatResponse(generations, this.from(chatCompletion, usage)); } - private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { + private ChatResponseMetadata from(ChatCompletionResponse result) { return from(result, this.getDefaultUsage(result.usage())); } - private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) { + private ChatResponseMetadata from(ChatCompletionResponse result, Usage usage) { Assert.notNull(result, "Anthropic ChatCompletionResult must not be null"); return ChatResponseMetadata.builder() .id(result.id()) @@ -443,6 +445,8 @@ Prompt buildRequestPrompt(Prompt prompt) { this.defaultOptions.getToolCallbacks())); requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), this.defaultOptions.getToolContext())); + requestOptions.setServerTools( + mergeServerTools(runtimeOptions.getServerTools(), this.defaultOptions.getServerTools())); } else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); @@ -450,6 +454,7 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); + requestOptions.setServerTools(this.defaultOptions.getServerTools()); } ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks()); @@ -457,6 +462,15 @@ Prompt buildRequestPrompt(Prompt prompt) { return new Prompt(prompt.getInstructions(), requestOptions); } + static List mergeServerTools(List runtimeServerTools, List defaultToolNames) { + Assert.notNull(runtimeServerTools, "runtimeServerTools cannot be null"); + Assert.notNull(defaultToolNames, "defaultToolNames cannot be null"); + if (CollectionUtils.isEmpty(runtimeServerTools)) { + return new ArrayList<>(defaultToolNames); + } + return new ArrayList<>(runtimeServerTools); + } + private Map mergeHttpHeaders(Map runtimeHttpHeaders, Map defaultHttpHeaders) { var mergedHttpHeaders = new HashMap<>(defaultHttpHeaders); @@ -526,22 +540,31 @@ else if (message.getMessageType() == MessageType.TOOL) { // Add the tool definitions to the request's tools parameter. List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); + Stream toolStream = Stream.empty(); if (!CollectionUtils.isEmpty(toolDefinitions)) { request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); - request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); + toolStream = getFunctionToolStream(toolDefinitions); + } + if (!CollectionUtils.isEmpty(requestOptions.getServerTools())) { + toolStream = Stream.concat(toolStream, requestOptions.getServerTools().stream()); + } + + List tools = toolStream.toList(); + if (!tools.isEmpty()) { + request = ChatCompletionRequest.from(request).tools(tools).build(); } return request; } - private List getFunctionTools(List toolDefinitions) { + private Stream getFunctionToolStream(List toolDefinitions) { return toolDefinitions.stream().map(toolDefinition -> { var name = toolDefinition.name(); var description = toolDefinition.description(); String inputSchema = toolDefinition.inputSchema(); - return new AnthropicApi.Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() { + return new Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() { })); - }).toList(); + }); } @Override diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index dbfbee561c8..3275ea61574 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -32,6 +32,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.anthropic.api.tool.Tool; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.lang.Nullable; @@ -44,6 +45,7 @@ * @author Thomas Vitale * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author Jonghoon Park * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -82,6 +84,8 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { @JsonIgnore private Map toolContext = new HashMap<>(); + @JsonIgnore + private List serverTools = new ArrayList<>(); /** * Optional HTTP headers to be added to the chat completion request. @@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) + .serverTools(fromOptions.getServerTools() != null ? new ArrayList<>(fromOptions.getServerTools()) : null) .httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null) .build(); } @@ -250,6 +255,17 @@ public void setToolContext(Map toolContext) { this.toolContext = toolContext; } + @JsonIgnore + public List getServerTools() { + return this.serverTools; + } + + public void setServerTools(List serverTools) { + Assert.notNull(serverTools, "serverTools cannot be null"); + Assert.noNullElements(serverTools, "serverTools cannot contain null elements"); + this.serverTools = serverTools; + } + @JsonIgnore public Map getHttpHeaders() { return this.httpHeaders; @@ -282,6 +298,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) && Objects.equals(this.toolContext, that.toolContext) + && Objects.equals(this.serverTools, that.serverTools) && Objects.equals(this.httpHeaders, that.httpHeaders); } @@ -289,7 +306,7 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP, this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext, this.httpHeaders); + this.toolContext, this.serverTools, this.httpHeaders); } public static class Builder { @@ -384,6 +401,16 @@ public Builder toolContext(Map toolContext) { return this; } + public Builder serverTools(List serverTools) { + if (this.options.serverTools == null) { + this.options.serverTools = serverTools; + } + else { + this.options.serverTools.addAll(serverTools); + } + return this; + } + public Builder httpHeaders(Map httpHeaders) { this.options.setHttpHeaders(httpHeaders); return this; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index e0a0a1bbf5f..67853d567f8 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -30,11 +30,16 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder; import org.springframework.ai.model.ApiKey; +import org.springframework.ai.anthropic.api.tool.Tool; +import org.springframework.ai.anthropic.util.ContentFieldDeserializer; +import org.springframework.ai.anthropic.util.ContentFieldSerializer; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.SimpleApiKey; @@ -750,7 +755,11 @@ public record ContentBlock( // tool_result response only @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content, + + @JsonSerialize(using = ContentFieldSerializer.class) + @JsonDeserialize(using = ContentFieldDeserializer.class) + @JsonProperty("content") + Object content, // Thinking only @JsonProperty("signature") String signature, @@ -761,6 +770,15 @@ public record ContentBlock( ) { // @formatter:on + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record WebSearchToolContentBlock(@JsonProperty("type") String type, @JsonProperty("title") String title, + @JsonProperty("url") String url, @JsonProperty("encrypted_content") String EncryptedContent, + @JsonProperty("page_age") String pageAge) { + + } + // @formatter:on + /** * Create content block * @param mediaType The media type of the content. @@ -846,6 +864,18 @@ public enum Type { @JsonProperty("tool_result") TOOL_RESULT("tool_result"), + /** + * Server Tool request + */ + @JsonProperty("server_tool_use") + SERVER_TOOL_USE("server_tool_use"), + + /** + * Web search tool result + */ + @JsonProperty("web_search_tool_result") + WEB_SEARCH_TOOL_RESULT("web_search_tool_result"), + /** * Text message. */ @@ -959,22 +989,6 @@ public Source(String url) { /// CONTENT_BLOCK EVENTS /////////////////////////////////////// - /** - * Tool description. - * - * @param name The name of the tool. - * @param description A description of the tool. - * @param inputSchema The input schema of the tool. - */ - @JsonInclude(Include.NON_NULL) - public record Tool( - // @formatter:off - @JsonProperty("name") String name, - @JsonProperty("description") String description, - @JsonProperty("input_schema") Map inputSchema) { - // @formatter:on - } - // CB START EVENT /** @@ -1020,16 +1034,25 @@ public record ChatCompletionResponse( public record Usage( // @formatter:off @JsonProperty("input_tokens") Integer inputTokens, - @JsonProperty("output_tokens") Integer outputTokens) { - // @formatter:off + @JsonProperty("output_tokens") Integer outputTokens, + @JsonProperty("server_tool_use") ServerToolUse serverToolUse) { + // @formatter:on + } + + @JsonInclude(Include.NON_NULL) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ServerToolUse( + // @formatter:off + @JsonProperty("web_search_requests") Integer webSearchRequests) { + // @formatter:on } - /// ECB STOP + /// ECB STOP /** * Special event used to aggregate multiple tool use events into a single event with * list of aggregated ContentBlockToolUse. - */ + */ public static class ToolUseAggregationEvent implements StreamEvent { private Integer index; @@ -1048,17 +1071,17 @@ public EventType type() { } /** - * Get tool content blocks. - * @return The tool content blocks. - */ + * Get tool content blocks. + * @return The tool content blocks. + */ public List getToolContentBlocks() { return this.toolContentBlocks; } /** - * Check if the event is empty. - * @return True if the event is empty, false otherwise. - */ + * Check if the event is empty. + * @return True if the event is empty, false otherwise. + */ public boolean isEmpty() { return (this.index == null || this.id == null || this.name == null || !StringUtils.hasText(this.partialJson)); @@ -1087,7 +1110,8 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) { void squashIntoContentBlock() { Map map = (StringUtils.hasText(this.partialJson)) ? ModelOptionsUtils.jsonToMap(this.partialJson) : Map.of(); - this.toolContentBlocks.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map)); + this.toolContentBlocks + .add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map)); this.index = null; this.id = null; this.name = null; @@ -1096,28 +1120,29 @@ void squashIntoContentBlock() { @Override public String toString() { - return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson=" - + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; + return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + + ", partialJson=" + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]"; } } - /////////////////////////////////////// - /// MESSAGE EVENTS - /////////////////////////////////////// + /////////////////////////////////////// + /// MESSAGE EVENTS + /////////////////////////////////////// - // MESSAGE START EVENT + // MESSAGE START EVENT /** * Content block start event. + * * @param type The event type. * @param index The index of the content block. * @param contentBlock The content block body. - */ + */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) public record ContentBlockStartEvent( - // @formatter:off + // @formatter:off @JsonProperty("type") EventType type, @JsonProperty("index") Integer index, @JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent { diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java index e08a9669085..2d10ee226d6 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -174,7 +174,7 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) { if (messageDeltaEvent.usage() != null) { var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(), - messageDeltaEvent.usage().outputTokens()); + messageDeltaEvent.usage().outputTokens(), contentBlockReference.get().usage.serverToolUse()); contentBlockReference.get().withUsage(totalUsage); } } diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/Tool.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/Tool.java new file mode 100644 index 00000000000..fd32c982daf --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/Tool.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.api.tool; + +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * The Anthropic Tool Data Class. + * + * @author Jonghoon Park + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class Tool { + + @JsonProperty("name") + String name; + + @JsonProperty("description") + String description; + + @JsonProperty("input_schema") + Map inputSchema; + + /** + * Tool description. + * @param name The name of the tool. + * @param description A description of the tool. + * @param inputSchema The input schema of the tool. + */ + public Tool(String name, String description, Map inputSchema) { + this.name = name; + this.description = description; + this.inputSchema = inputSchema; + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/search/UserLocation.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/search/UserLocation.java new file mode 100644 index 00000000000..2d60146e15e --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/search/UserLocation.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.api.tool.search; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * implementation for [WebSearchTool](...) + * + * @author Jonghoon Park + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public record UserLocation( +// @formatter:off + @JsonProperty("type") String type, + @JsonProperty("city") String city, + @JsonProperty("region") String region, + @JsonProperty("country") String country, + @JsonProperty("timezone") String timezone) { + // @formatter:on +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/search/WebSearchTool.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/search/WebSearchTool.java new file mode 100644 index 00000000000..69c60a40bfc --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/tool/search/WebSearchTool.java @@ -0,0 +1,125 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.api.tool.search; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.anthropic.api.tool.Tool; +import org.springframework.util.Assert; + +/** + * implementation for [WebSearchTool](...) + * + * @author Jonghoon Park + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class WebSearchTool extends Tool { + + public static Builder builder() { + return new Builder(); + } + + private static final String WEB_SEARCH = "web_search"; + + private static final String WEB_SEARCH_20250305 = "web_search_20250305"; + + @JsonProperty("type") + private final String type; + + @JsonProperty("max_uses") + private final Integer maxUses; + + @JsonProperty("allowed_domains") + private final List allowedDomains; + + @JsonProperty("blocked_domains") + private final List blockedDomains; + + @JsonProperty("user_location") + private final UserLocation userLocation; + + public WebSearchTool(String name, String type, Integer maxUses, List allowedDomains, + List blockedDomains, UserLocation userLocation) { + super(name, null, null); + this.type = type; + this.maxUses = maxUses; + this.allowedDomains = allowedDomains; + this.blockedDomains = blockedDomains; + this.userLocation = userLocation; + } + + public static class Builder { + + private String name = WEB_SEARCH; + + private String type = WEB_SEARCH_20250305; + + private Integer maxUses = 5; + + private List allowedDomains; + + private List blockedDomains; + + private UserLocation userLocation; + + public Builder name(String name) { + Assert.notNull(name, "name cannot be null"); + this.name = name; + return this; + } + + public Builder type(String type) { + Assert.notNull(type, "type cannot be null"); + this.type = type; + return this; + } + + public Builder maxUses(int maxUses) { + this.maxUses = maxUses; + return this; + } + + public Builder allowedDomains(List allowedDomains) { + Assert.notNull(allowedDomains, "allowedDomains cannot be null"); + this.allowedDomains = allowedDomains; + return this; + } + + public Builder blockedDomains(List blockedDomains) { + Assert.notNull(blockedDomains, "blockedDomains cannot be null"); + this.blockedDomains = blockedDomains; + return this; + } + + public Builder userLocation(UserLocation userLocation) { + Assert.notNull(userLocation, "userLocation cannot be null"); + this.userLocation = userLocation; + return this; + } + + public WebSearchTool build() { + return new WebSearchTool(this.name, this.type, this.maxUses, this.allowedDomains, this.blockedDomains, + this.userLocation); + } + + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/util/ContentFieldDeserializer.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/util/ContentFieldDeserializer.java new file mode 100644 index 00000000000..c6feda2fde5 --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/util/ContentFieldDeserializer.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.util; + +import java.io.IOException; +import java.util.List; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; + +import org.springframework.ai.anthropic.api.AnthropicApi; + +/** + * Deserialize the multi-type field {@link AnthropicApi.ContentBlock#content()}. + * + * @author Jonghoon Park + */ +public class ContentFieldDeserializer extends JsonDeserializer { + + @Override + public Object deserialize(JsonParser parser, DeserializationContext deserializationContext) throws IOException { + JsonNode node = parser.getCodec().readTree(parser); + + if (node.isTextual()) { + return node.asText(); + } + else if (node.isArray()) { + return parser.getCodec() + .readValue(node.traverse(), + new TypeReference>() { + }); + } + + return null; + } + +} diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/util/ContentFieldSerializer.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/util/ContentFieldSerializer.java new file mode 100644 index 00000000000..9780418038a --- /dev/null +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/util/ContentFieldSerializer.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.util; + +import java.io.IOException; +import java.util.List; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; + +import org.springframework.ai.anthropic.api.AnthropicApi; + +/** + * Serialize the multi-type field {@link AnthropicApi.ContentBlock#content()}. + * + * @author Jonghoon Park + */ +public class ContentFieldSerializer extends JsonSerializer { + + @Override + public void serialize(Object o, JsonGenerator generator, SerializerProvider serializerProvider) throws IOException { + if (o instanceof String content) { + generator.writeString(content); + } + else if (o instanceof List list) { + generator.writeStartArray(); + for (Object block : list) { + generator.writeObject(block); + } + generator.writeEndArray(); + } + else { + generator.writeNull(); + } + } + +} diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index 35cf443866c..d259dfb690d 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -28,6 +28,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.Role; +import org.springframework.ai.anthropic.api.tool.Tool; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.http.ResponseEntity; @@ -44,7 +45,7 @@ public class AnthropicApiIT { AnthropicApi anthropicApi = AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); - List tools = List.of(new AnthropicApi.Tool("getCurrentWeather", + List tools = List.of(new Tool("getCurrentWeather", "Get the weather in location. Return temperature in 30°F or 30°C format.", ModelOptionsUtils.jsonToMap(""" { "type": "object", diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java index d2f6d30f2ec..5c36190614e 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java @@ -33,7 +33,6 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type; import org.springframework.ai.anthropic.api.AnthropicApi.Role; -import org.springframework.ai.anthropic.api.AnthropicApi.Tool; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.http.ResponseEntity; import org.springframework.util.CollectionUtils; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/search/AnthropicSearchToolIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/search/AnthropicSearchToolIT.java new file mode 100644 index 00000000000..ae17bd12803 --- /dev/null +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/search/AnthropicSearchToolIT.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic.api.tool.search; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.anthropic.api.AnthropicApi; +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse; +import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock; +import org.springframework.http.ResponseEntity; + +import static org.assertj.core.api.Assertions.assertThat; + +@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") +public class AnthropicSearchToolIT { + + private static final Logger logger = LoggerFactory.getLogger(AnthropicSearchToolIT.class); + + AnthropicApi anthropicApi = AnthropicApi.builder().apiKey(System.getenv("ANTHROPIC_API_KEY")).build(); + + @Test + void chatCompletionWithWebSearchTool() { + List messageConversation = new ArrayList<>(); + + AnthropicApi.AnthropicMessage chatCompletionMessage = new AnthropicApi.AnthropicMessage( + List.of(new AnthropicApi.ContentBlock( + "What's the weather like in San Francisco? Show the temperature in Celsius.")), + AnthropicApi.Role.USER); + + messageConversation.add(chatCompletionMessage); + + AnthropicApi.ChatCompletionRequest chatCompletionRequest = AnthropicApi.ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_5_HAIKU) + .messages(messageConversation) + .maxTokens(1024) + .tools(List.of(WebSearchTool.builder().maxUses(1).build())) + .build(); + + ResponseEntity response = this.anthropicApi.chatCompletionEntity(chatCompletionRequest); + + var responseText = response.getBody() + .content() + .stream() + .filter(contentBlock -> contentBlock.type() == ContentBlock.Type.TEXT) + .map(ContentBlock::text) + .map(str -> str.replace("\n", " ").trim()) + .collect(Collectors.joining()); + logger.info("RESPONSE: " + responseText); + + assertThat(response.getBody().usage().serverToolUse()).isNotNull(); + assertThat(response.getBody().usage().serverToolUse().webSearchRequests()).isGreaterThan(0); + } + +} diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index 5bd91be7ebb..63a3df78431 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -35,10 +35,13 @@ import org.springframework.ai.anthropic.AnthropicTestConfiguration; import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.tool.MockWeatherService; +import org.springframework.ai.anthropic.api.tool.search.WebSearchTool; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.ai.chat.messages.AbstractMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.test.CurlyBracketEscaper; @@ -339,6 +342,37 @@ void streamingMultiModality() throws IOException { assertThat(content).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); } + @Test + void webSearchToolCallTest() { + + // @formatter:off + ChatResponse chatResponse = ChatClient.builder(this.chatModel) + .defaultOptions(AnthropicChatOptions.builder() + .maxTokens(1024) + .serverTools(List.of(WebSearchTool.builder().maxUses(1).build())) + .build()) + .defaultUser(u -> u.text("What's the weather like in San Francisco? Show the temperature in Celsius.")) + .build() + .prompt() + .call() + .chatResponse(); + + AnthropicApi.Usage nativeUsage = (AnthropicApi.Usage) chatResponse.getMetadata().getUsage().getNativeUsage(); + // @formatter:on + + logger.info("nativeUsage: {}", nativeUsage); + logger.info("Response: {}", + chatResponse.getResults() + .stream() + .map(Generation::getOutput) + .map(AbstractMessage::getText) + .map(text -> text.replace("\n", "").trim()) + .collect(Collectors.joining(", "))); + + assertThat(nativeUsage.serverToolUse()).isNotNull(); + assertThat(nativeUsage.serverToolUse().webSearchRequests()).isGreaterThan(0); + } + record ActorsFilms(String actor, List movies) { }