Skip to content

feat(anthropic): Implemented a web search tool, provided by Anthropic #3353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.fasterxml.jackson.core.type.TypeReference;
import io.micrometer.observation.Observation;
Expand All @@ -42,6 +43,7 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.api.tool.Tool;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
Expand Down Expand Up @@ -342,11 +344,11 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
return new ChatResponse(generations, this.from(chatCompletion, usage));
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
private ChatResponseMetadata from(ChatCompletionResponse result) {
return from(result, this.getDefaultUsage(result.usage()));
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result, Usage usage) {
private ChatResponseMetadata from(ChatCompletionResponse result, Usage usage) {
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
.id(result.id())
Expand Down Expand Up @@ -443,20 +445,32 @@ Prompt buildRequestPrompt(Prompt prompt) {
this.defaultOptions.getToolCallbacks()));
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
this.defaultOptions.getToolContext()));
requestOptions.setServerTools(
mergeServerTools(runtimeOptions.getServerTools(), this.defaultOptions.getServerTools()));
}
else {
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
requestOptions.setToolContext(this.defaultOptions.getToolContext());
requestOptions.setServerTools(this.defaultOptions.getServerTools());
}

ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());

return new Prompt(prompt.getInstructions(), requestOptions);
}

static List<Tool> mergeServerTools(List<Tool> runtimeServerTools, List<Tool> defaultToolNames) {
Assert.notNull(runtimeServerTools, "runtimeServerTools cannot be null");
Assert.notNull(defaultToolNames, "defaultToolNames cannot be null");
if (CollectionUtils.isEmpty(runtimeServerTools)) {
return new ArrayList<>(defaultToolNames);
}
return new ArrayList<>(runtimeServerTools);
}

private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHeaders,
Map<String, String> defaultHttpHeaders) {
var mergedHttpHeaders = new HashMap<>(defaultHttpHeaders);
Expand Down Expand Up @@ -526,22 +540,31 @@ else if (message.getMessageType() == MessageType.TOOL) {

// Add the tool definitions to the request's tools parameter.
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
Stream<Tool> toolStream = Stream.empty();
if (!CollectionUtils.isEmpty(toolDefinitions)) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
Copy link
Contributor Author

@dev-jonghoonpark dev-jonghoonpark Jun 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);

I'm not sure why this code is needed or what it does.
I kept it to preserve the existing logic.

If I can understand why this is necessary, I think I can improve the code further.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dev-jonghoonpark Good point. This line makes sure to update the ChatCompletionRequest with any of the missing chat options which can be set via default chat options.

Though we use the default options to create the ChatCompletion Request:

ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
				systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);

not all the properties are updated in this and the model options utils helps merging any of the missing ones.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ilayaperumalg Thank you for the explanation.
However, why is it called only when !CollectionUtils.isEmpty(toolDefinitions) is true?

request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
toolStream = getFunctionToolStream(toolDefinitions);
}
if (!CollectionUtils.isEmpty(requestOptions.getServerTools())) {
toolStream = Stream.concat(toolStream, requestOptions.getServerTools().stream());
}

List<Tool> tools = toolStream.toList();
if (!tools.isEmpty()) {
request = ChatCompletionRequest.from(request).tools(tools).build();
}

return request;
}

private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefinitions) {
private Stream<Tool> getFunctionToolStream(List<ToolDefinition> toolDefinitions) {
return toolDefinitions.stream().map(toolDefinition -> {
var name = toolDefinition.name();
var description = toolDefinition.description();
String inputSchema = toolDefinition.inputSchema();
return new AnthropicApi.Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
return new Tool(name, description, JsonParser.fromJson(inputSchema, new TypeReference<>() {
}));
}).toList();
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.anthropic.api.tool.Tool;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
Expand All @@ -44,6 +45,7 @@
* @author Thomas Vitale
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
* @author Jonghoon Park
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
Expand Down Expand Up @@ -82,6 +84,8 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Map<String, Object> toolContext = new HashMap<>();

@JsonIgnore
private List<Tool> serverTools = new ArrayList<>();

/**
* Optional HTTP headers to be added to the chat completion request.
Expand Down Expand Up @@ -110,6 +114,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.serverTools(fromOptions.getServerTools() != null ? new ArrayList<>(fromOptions.getServerTools()) : null)
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
.build();
}
Expand Down Expand Up @@ -250,6 +255,17 @@ public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}

@JsonIgnore
public List<Tool> getServerTools() {
return this.serverTools;
}

public void setServerTools(List<Tool> serverTools) {
Assert.notNull(serverTools, "serverTools cannot be null");
Assert.noNullElements(serverTools, "serverTools cannot contain null elements");
this.serverTools = serverTools;
}

@JsonIgnore
public Map<String, String> getHttpHeaders() {
return this.httpHeaders;
Expand Down Expand Up @@ -282,14 +298,15 @@ public boolean equals(Object o) {
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.toolContext, that.toolContext)
&& Objects.equals(this.serverTools, that.serverTools)
&& Objects.equals(this.httpHeaders, that.httpHeaders);
}

@Override
public int hashCode() {
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
this.toolContext, this.httpHeaders);
this.toolContext, this.serverTools, this.httpHeaders);
}

public static class Builder {
Expand Down Expand Up @@ -384,6 +401,16 @@ public Builder toolContext(Map<String, Object> toolContext) {
return this;
}

public Builder serverTools(List<Tool> serverTools) {
if (this.options.serverTools == null) {
this.options.serverTools = serverTools;
}
else {
this.options.serverTools.addAll(serverTools);
}
return this;
}

public Builder httpHeaders(Map<String, String> httpHeaders) {
this.options.setHttpHeaders(httpHeaders);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,16 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
import org.springframework.ai.model.ApiKey;
import org.springframework.ai.anthropic.api.tool.Tool;
import org.springframework.ai.anthropic.util.ContentFieldDeserializer;
import org.springframework.ai.anthropic.util.ContentFieldSerializer;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.SimpleApiKey;
Expand Down Expand Up @@ -750,7 +755,11 @@ public record ContentBlock(

// tool_result response only
@JsonProperty("tool_use_id") String toolUseId,
@JsonProperty("content") String content,

@JsonSerialize(using = ContentFieldSerializer.class)
@JsonDeserialize(using = ContentFieldDeserializer.class)
@JsonProperty("content")
Object content,

// Thinking only
@JsonProperty("signature") String signature,
Expand All @@ -761,6 +770,15 @@ public record ContentBlock(
) {
// @formatter:on

@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record WebSearchToolContentBlock(@JsonProperty("type") String type, @JsonProperty("title") String title,
@JsonProperty("url") String url, @JsonProperty("encrypted_content") String EncryptedContent,
@JsonProperty("page_age") String pageAge) {

}
// @formatter:on

/**
* Create content block
* @param mediaType The media type of the content.
Expand Down Expand Up @@ -846,6 +864,18 @@ public enum Type {
@JsonProperty("tool_result")
TOOL_RESULT("tool_result"),

/**
* Server Tool request
*/
@JsonProperty("server_tool_use")
SERVER_TOOL_USE("server_tool_use"),

/**
* Web search tool result
*/
@JsonProperty("web_search_tool_result")
WEB_SEARCH_TOOL_RESULT("web_search_tool_result"),

/**
* Text message.
*/
Expand Down Expand Up @@ -959,22 +989,6 @@ public Source(String url) {
/// CONTENT_BLOCK EVENTS
///////////////////////////////////////

/**
* Tool description.
*
* @param name The name of the tool.
* @param description A description of the tool.
* @param inputSchema The input schema of the tool.
*/
@JsonInclude(Include.NON_NULL)
public record Tool(
// @formatter:off
@JsonProperty("name") String name,
@JsonProperty("description") String description,
@JsonProperty("input_schema") Map<String, Object> inputSchema) {
// @formatter:on
}

// CB START EVENT

/**
Expand Down Expand Up @@ -1020,16 +1034,25 @@ public record ChatCompletionResponse(
public record Usage(
// @formatter:off
@JsonProperty("input_tokens") Integer inputTokens,
@JsonProperty("output_tokens") Integer outputTokens) {
// @formatter:off
@JsonProperty("output_tokens") Integer outputTokens,
@JsonProperty("server_tool_use") ServerToolUse serverToolUse) {
// @formatter:on
}

@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ServerToolUse(
// @formatter:off
@JsonProperty("web_search_requests") Integer webSearchRequests) {
// @formatter:on
}

/// ECB STOP
/// ECB STOP

/**
* Special event used to aggregate multiple tool use events into a single event with
* list of aggregated ContentBlockToolUse.
*/
*/
public static class ToolUseAggregationEvent implements StreamEvent {

private Integer index;
Expand All @@ -1048,17 +1071,17 @@ public EventType type() {
}

/**
* Get tool content blocks.
* @return The tool content blocks.
*/
* Get tool content blocks.
* @return The tool content blocks.
*/
public List<ContentBlockStartEvent.ContentBlockToolUse> getToolContentBlocks() {
return this.toolContentBlocks;
}

/**
* Check if the event is empty.
* @return True if the event is empty, false otherwise.
*/
* Check if the event is empty.
* @return True if the event is empty, false otherwise.
*/
public boolean isEmpty() {
return (this.index == null || this.id == null || this.name == null
|| !StringUtils.hasText(this.partialJson));
Expand Down Expand Up @@ -1087,7 +1110,8 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
void squashIntoContentBlock() {
Map<String, Object> map = (StringUtils.hasText(this.partialJson))
? ModelOptionsUtils.jsonToMap(this.partialJson) : Map.of();
this.toolContentBlocks.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
this.toolContentBlocks
.add(new ContentBlockStartEvent.ContentBlockToolUse("tool_use", this.id, this.name, map));
this.index = null;
this.id = null;
this.name = null;
Expand All @@ -1096,28 +1120,29 @@ void squashIntoContentBlock() {

@Override
public String toString() {
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name + ", partialJson="
+ this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
return "EventToolUseBuilder [index=" + this.index + ", id=" + this.id + ", name=" + this.name
+ ", partialJson=" + this.partialJson + ", toolUseMap=" + this.toolContentBlocks + "]";
}

}

///////////////////////////////////////
/// MESSAGE EVENTS
///////////////////////////////////////
///////////////////////////////////////
/// MESSAGE EVENTS
///////////////////////////////////////

// MESSAGE START EVENT
// MESSAGE START EVENT

/**
* Content block start event.
*
* @param type The event type.
* @param index The index of the content block.
* @param contentBlock The content block body.
*/
*/
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ContentBlockStartEvent(
// @formatter:off
// @formatter:off
@JsonProperty("type") EventType type,
@JsonProperty("index") Integer index,
@JsonProperty("content_block") ContentBlockBody contentBlock) implements StreamEvent {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -174,7 +174,7 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {

if (messageDeltaEvent.usage() != null) {
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
messageDeltaEvent.usage().outputTokens());
messageDeltaEvent.usage().outputTokens(), contentBlockReference.get().usage.serverToolUse());
contentBlockReference.get().withUsage(totalUsage);
}
}
Expand Down
Loading