Skip to content

Commit 3aafc80

Browse files
Implemented a web search tool, provided by Anthropic
This web search tool is categorized as a `server tool`. related doc: https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool Signed-off-by: jonghoonpark <[email protected]>
1 parent 3919204 commit 3aafc80

File tree

13 files changed

+552
-47
lines changed

13 files changed

+552
-47
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Map;
2424
import java.util.Set;
2525
import java.util.stream.Collectors;
26+
import java.util.stream.Stream;
2627

2728
import com.fasterxml.jackson.core.type.TypeReference;
2829
import io.micrometer.observation.Observation;
@@ -42,6 +43,7 @@
4243
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
4344
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
4445
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
46+
import org.springframework.ai.anthropic.api.tool.Tool;
4547
import org.springframework.ai.chat.messages.AssistantMessage;
4648
import org.springframework.ai.chat.messages.MessageType;
4749
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -342,11 +344,11 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
342344
return new ChatResponse(generations, this.from(chatCompletion, usage));
343345
}
344346

345-
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
347+
private ChatResponseMetadata from(ChatCompletionResponse result) {
346348
return from(result, this.getDefaultUsage(result.usage()));
347349
}
348350

349-
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
351+
private ChatResponseMetadata from(ChatCompletionResponse result, Usage usage) {
350352
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
351353
return ChatResponseMetadata.builder()
352354
.id(result.id())
@@ -443,20 +445,32 @@ Prompt buildRequestPrompt(Prompt prompt) {
443445
this.defaultOptions.getToolCallbacks()));
444446
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
445447
this.defaultOptions.getToolContext()));
448+
requestOptions.setServerTools(
449+
mergeServerTools(runtimeOptions.getServerTools(), this.defaultOptions.getServerTools()));
446450
}
447451
else {
448452
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
449453
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
450454
requestOptions.setToolNames(this.defaultOptions.getToolNames());
451455
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
452456
requestOptions.setToolContext(this.defaultOptions.getToolContext());
457+
requestOptions.setServerTools(this.defaultOptions.getServerTools());
453458
}
454459

455460
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
456461

457462
return new Prompt(prompt.getInstructions(), requestOptions);
458463
}
459464

465+
static List<Tool> mergeServerTools(List<Tool> runtimeServerTools, List<Tool> defaultToolNames) {
466+
Assert.notNull(runtimeServerTools, "runtimeServerTools cannot be null");
467+
Assert.notNull(defaultToolNames, "defaultToolNames cannot be null");
468+
if (CollectionUtils.isEmpty(runtimeServerTools)) {
469+
return new ArrayList<>(defaultToolNames);
470+
}
471+
return new ArrayList<>(runtimeServerTools);
472+
}
473+
460474
private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,
461475
Map<String, String> defaultHttpHeaders) {
462476
var mergedHttpHeaders = new HashMap<>(defaultHttpHeaders);
@@ -526,22 +540,31 @@ else if (message.getMessageType() == MessageType.TOOL) {
526540

527541
// Add the tool definitions to the request's tools parameter.
528542
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
543+
Stream<Tool> toolStream = Stream.empty();
529544
if (!CollectionUtils.isEmpty(toolDefinitions)) {
530545
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
531-
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
546+
toolStream = getFunctionToolStream(toolDefinitions);
547+
}
548+
if (!CollectionUtils.isEmpty(requestOptions.getServerTools())) {
549+
toolStream = Stream.concat(toolStream, requestOptions.getServerTools().stream());
550+
}
551+
552+
List<Tool> tools = toolStream.toList();
553+
if (!tools.isEmpty()) {
554+
request = ChatCompletionRequest.from(request).tools(tools).build();
532555
}
533556

534557
return request;
535558
}
536559

537-
private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
560+
private Stream<Tool> getFunctionToolStream(List<ToolDefinition> toolDefinitions) {
538561
return toolDefinitions.stream().map(toolDefinition -> {
539562
var name = toolDefinition.name();
540563
var description = toolDefinition.description();
541564
String inputSchema = toolDefinition.inputSchema();
542-
return new AnthropicApi.Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
565+
return new Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
543566
}));
544-
}).toList();
567+
});
545568
}
546569

547570
@Override

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import org.springframework.ai.anthropic.api.AnthropicApi;
3434
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
35+
import org.springframework.ai.anthropic.api.tool.Tool;
3536
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3637
import org.springframework.ai.tool.ToolCallback;
3738
import org.springframework.lang.Nullable;
@@ -44,6 +45,7 @@
4445
* @author Thomas Vitale
4546
* @author Alexandros Pappas
4647
* @author Ilayaperumal Gopinathan
48+
* @author Jonghoon Park
4749
* @since 1.0.0
4850
*/
4951
@JsonInclude(Include.NON_NULL)
@@ -82,6 +84,8 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
8284
@JsonIgnore
8385
private Map<String, Object> toolContext = new HashMap<>();
8486

87+
@JsonIgnore
88+
private List<Tool> serverTools = new ArrayList<>();
8589

8690
/**
8791
* Optional HTTP headers to be added to the chat completion request.
@@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
110114
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
111115
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
112116
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
117+
.serverTools(fromOptions.getServerTools() != null ? new ArrayList<>(fromOptions.getServerTools()) : null)
113118
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
114119
.build();
115120
}
@@ -250,6 +255,17 @@ public void setToolContext(Map<String, Object> toolContext) {
250255
this.toolContext = toolContext;
251256
}
252257

258+
@JsonIgnore
259+
public List<Tool> getServerTools() {
260+
return this.serverTools;
261+
}
262+
263+
public void setServerTools(List<Tool> serverTools) {
264+
Assert.notNull(serverTools, "serverTools cannot be null");
265+
Assert.noNullElements(serverTools, "serverTools cannot contain null elements");
266+
this.serverTools = serverTools;
267+
}
268+
253269
@JsonIgnore
254270
public Map<String, String> getHttpHeaders() {
255271
return this.httpHeaders;
@@ -282,14 +298,15 @@ public boolean equals(Object o) {
282298
&& Objects.equals(this.toolNames, that.toolNames)
283299
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
284300
&& Objects.equals(this.toolContext, that.toolContext)
301+
&& Objects.equals(this.serverTools, that.serverTools)
285302
&& Objects.equals(this.httpHeaders, that.httpHeaders);
286303
}
287304

288305
@Override
289306
public int hashCode() {
290307
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
291308
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
292-
this.toolContext, this.httpHeaders);
309+
this.toolContext, this.serverTools, this.httpHeaders);
293310
}
294311

295312
public static class Builder {
@@ -384,6 +401,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
384401
return this;
385402
}
386403

404+
public Builder serverTools(List<Tool> serverTools) {
405+
if (this.options.serverTools == null) {
406+
this.options.serverTools = serverTools;
407+
}
408+
else {
409+
this.options.serverTools.addAll(serverTools);
410+
}
411+
return this;
412+
}
413+
387414
public Builder httpHeaders(Map<String, String> httpHeaders) {
388415
this.options.setHttpHeaders(httpHeaders);
389416
return this;

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,15 @@
3030
import com.fasterxml.jackson.annotation.JsonProperty;
3131
import com.fasterxml.jackson.annotation.JsonSubTypes;
3232
import com.fasterxml.jackson.annotation.JsonTypeInfo;
33+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
34+
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
3335
import reactor.core.publisher.Flux;
3436
import reactor.core.publisher.Mono;
3537

3638
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
39+
import org.springframework.ai.anthropic.api.tool.Tool;
40+
import org.springframework.ai.anthropic.util.ContentFieldDeserializer;
41+
import org.springframework.ai.anthropic.util.ContentFieldSerializer;
3742
import org.springframework.ai.model.ChatModelDescription;
3843
import org.springframework.ai.model.ModelOptionsUtils;
3944
import org.springframework.ai.observation.conventions.AiProvider;
@@ -717,7 +722,11 @@ public record ContentBlock(
717722

718723
// tool_result response only
719724
@JsonProperty("tool_use_id") String toolUseId,
720-
@JsonProperty("content") String content,
725+
726+
@JsonSerialize(using = ContentFieldSerializer.class)
727+
@JsonDeserialize(using = ContentFieldDeserializer.class)
728+
@JsonProperty("content")
729+
Object content,
721730

722731
// Thinking only
723732
@JsonProperty("signature") String signature,
@@ -728,6 +737,15 @@ public record ContentBlock(
728737
) {
729738
// @formatter:on
730739

740+
@JsonInclude(Include.NON_NULL)
741+
@JsonIgnoreProperties(ignoreUnknown = true)
742+
public record WebSearchToolContentBlock(@JsonProperty("type") String type, @JsonProperty("title") String title,
743+
@JsonProperty("url") String url, @JsonProperty("encrypted_content") String EncryptedContent,
744+
@JsonProperty("page_age") String pageAge) {
745+
746+
}
747+
// @formatter:on
748+
731749
/**
732750
* Create content block
733751
* @param mediaType The media type of the content.
@@ -813,6 +831,18 @@ public enum Type {
813831
@JsonProperty("tool_result")
814832
TOOL_RESULT("tool_result"),
815833

834+
/**
835+
* Server Tool request
836+
*/
837+
@JsonProperty("server_tool_use")
838+
SERVER_TOOL_USE("server_tool_use"),
839+
840+
/**
841+
* Web search tool result
842+
*/
843+
@JsonProperty("web_search_tool_result")
844+
WEB_SEARCH_TOOL_RESULT("web_search_tool_result"),
845+
816846
/**
817847
* Text message.
818848
*/
@@ -926,22 +956,6 @@ public Source(String url) {
926956
/// CONTENT_BLOCK EVENTS
927957
///////////////////////////////////////
928958

929-
/**
930-
* Tool description.
931-
*
932-
* @param name The name of the tool.
933-
* @param description A description of the tool.
934-
* @param inputSchema The input schema of the tool.
935-
*/
936-
@JsonInclude(Include.NON_NULL)
937-
public record Tool(
938-
// @formatter:off
939-
@JsonProperty("name") String name,
940-
@JsonProperty("description") String description,
941-
@JsonProperty("input_schema") Map<String, Object> inputSchema) {
942-
// @formatter:on
943-
}
944-
945959
// CB START EVENT
946960

947961
/**
@@ -987,16 +1001,25 @@ public record ChatCompletionResponse(
9871001
public record Usage(
9881002
// @formatter:off
9891003
@JsonProperty("input_tokens") Integer inputTokens,
990-
@JsonProperty("output_tokens") Integer outputTokens) {
991-
// @formatter:off
1004+
@JsonProperty("output_tokens") Integer outputTokens,
1005+
@JsonProperty("server_tool_use") ServerToolUse serverToolUse) {
1006+
// @formatter:on
1007+
}
1008+
1009+
@JsonInclude(Include.NON_NULL)
1010+
@JsonIgnoreProperties(ignoreUnknown = true)
1011+
public record ServerToolUse(
1012+
// @formatter:off
1013+
@JsonProperty("web_search_requests") Integer webSearchRequests) {
1014+
// @formatter:on
9921015
}
9931016

994-
/// ECB STOP
1017+
/// ECB STOP
9951018

9961019
/**
9971020
* Special event used to aggregate multiple tool use events into a single event with
9981021
* list of aggregated ContentBlockToolUse.
999-
*/
1022+
*/
10001023
public static class ToolUseAggregationEvent implements StreamEvent {
10011024

10021025
private Integer index;
@@ -1015,17 +1038,17 @@ public EventType type() {
10151038
}
10161039

10171040
/**
1018-
* Get tool content blocks.
1019-
* @return The tool content blocks.
1020-
*/
1041+
* Get tool content blocks.
1042+
* @return The tool content blocks.
1043+
*/
10211044
public List<ContentBlockStartEvent.ContentBlockToolUse> getToolContentBlocks() {
10221045
return this.toolContentBlocks;
10231046
}
10241047

10251048
/**
1026-
* Check if the event is empty.
1027-
* @return True if the event is empty, false otherwise.
1028-
*/
1049+
* Check if the event is empty.
1050+
* @return True if the event is empty, false otherwise.
1051+
*/
10291052
public boolean isEmpty() {
10301053
return (this.index == null || this.id == null || this.name == null
10311054
|| !StringUtils.hasText(this.partialJson));
@@ -1054,7 +1077,8 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
10541077
void squashIntoContentBlock() {
10551078
Map<String, Object> map = (StringUtils.hasText(this.partialJson))
10561079
? ModelOptionsUtils.jsonToMap(this.partialJson) : Map.of();
1057-
this.toolContentBlocks.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
1080+
this.toolContentBlocks
1081+
.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
10581082
this.index = null;
10591083
this.id = null;
10601084
this.name = null;
@@ -1063,28 +1087,29 @@ void squashIntoContentBlock() {
10631087

10641088
@Override
10651089
public String toString() {
1066-
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson="
1067-
+ this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
1090+
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name
1091+
+ ", partialJson=" + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
10681092
}
10691093

10701094
}
10711095

1072-
///////////////////////////////////////
1073-
/// MESSAGE EVENTS
1074-
///////////////////////////////////////
1096+
///////////////////////////////////////
1097+
/// MESSAGE EVENTS
1098+
///////////////////////////////////////
10751099

1076-
// MESSAGE START EVENT
1100+
// MESSAGE START EVENT
10771101

10781102
/**
10791103
* Content block start event.
1104+
*
10801105
* @param type The event type.
10811106
* @param index The index of the content block.
10821107
* @param contentBlock The content block body.
1083-
*/
1108+
*/
10841109
@JsonInclude(Include.NON_NULL)
10851110
@JsonIgnoreProperties(ignoreUnknown = true)
10861111
public record ContentBlockStartEvent(
1087-
// @formatter:off
1112+
// @formatter:off
10881113
@JsonProperty("type") EventType type,
10891114
@JsonProperty("index") Integer index,
10901115
@JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent {

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/StreamHelper.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -174,7 +174,7 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
174174

175175
if (messageDeltaEvent.usage() != null) {
176176
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
177-
messageDeltaEvent.usage().outputTokens());
177+
messageDeltaEvent.usage().outputTokens(), contentBlockReference.get().usage.serverToolUse());
178178
contentBlockReference.get().withUsage(totalUsage);
179179
}
180180
}

0 commit comments

Comments
 (0)