Skip to content

Add Bedrock Anthropic3 tool support. #810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,28 @@
*/
package org.springframework.ai.bedrock.anthropic3;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* @author Ben Middleton
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class Anthropic3ChatOptions implements ChatOptions {
public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions {

// @formatter:off
/**
Expand Down Expand Up @@ -66,6 +75,32 @@ public class Anthropic3ChatOptions implements ChatOptions {
* The version of the generative to use. The default value is bedrock-2023-05-31.
*/
private @JsonProperty("anthropic_version") String anthropicVersion;

/**
* Tool Function Callbacks to register with the ChatModel. For Prompt
* Options the functionCallbacks are automatically enabled for the duration of the
* prompt execution. For Default Options the functionCallbacks are registered but
* disabled by default. Use the enableFunctions to set the functions from the registry
* to be used by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests. Functions with those names must exist in the
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
* are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat
* completion requests. This could impact the token count and the billing. If the
* functions is set in a prompt options, then the enabled functions are only active
* for the duration of this prompt execution.
*/
@NestedConfigurationProperty
@JsonIgnore
private Set<String> functions = new HashSet<>();
// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -106,6 +141,23 @@ public Builder withAnthropicVersion(String anthropicVersion) {
return this;
}

public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return this;
}

public Builder withFunctions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this;
}

public Builder withFunction(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this;
}

public Anthropic3ChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -163,13 +215,37 @@ public void setAnthropicVersion(String anthropicVersion) {
this.anthropicVersion = anthropicVersion;
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
}

@Override
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
this.functionCallbacks = functionCallbacks;
}

@Override
public Set<String> getFunctions() {
return this.functions;
}

@Override
public void setFunctions(Set<String> functions) {
Assert.notNull(functions, "Function must not be null");
this.functions = functions;
}

public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withMaxTokens(fromOptions.getMaxTokens())
.withTopK(fromOptions.getTopK())
.withTopP(fromOptions.getTopP())
.withStopSequences(fromOptions.getStopSequences())
.withAnthropicVersion(fromOptions.getAnthropicVersion())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

Expand All @@ -30,6 +32,7 @@
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent.Type;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand All @@ -40,6 +43,8 @@
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.CollectionUtils;

/**
Expand All @@ -48,9 +53,12 @@
*
* @author Ben Middleton
* @author Christian Tzolov
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
public class BedrockAnthropic3ChatModel extends
AbstractFunctionCallSupport<Anthropic3ChatBedrockApi.ChatCompletionMessage, Anthropic3ChatBedrockApi.AnthropicChatRequest, Anthropic3ChatBedrockApi.AnthropicChatResponse>
implements ChatModel, StreamingChatModel {

private final Anthropic3ChatBedrockApi anthropicChatApi;

Expand All @@ -67,6 +75,13 @@ public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) {
}

public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) {
this(chatApi, options, null);
}

public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options,
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);

this.anthropicChatApi = chatApi;
this.defaultOptions = options;
}
Expand All @@ -76,7 +91,7 @@ public ChatResponse call(Prompt prompt) {

AnthropicChatRequest request = createRequest(prompt);

AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);
AnthropicChatResponse response = this.callWithFunctionSupport(request);

return new ChatResponse(List.of(new Generation(response.content().get(0).text())));
}
Expand Down Expand Up @@ -117,22 +132,38 @@ AnthropicChatRequest createRequest(Prompt prompt) {
.withSystem(toAnthropicSystemContext(prompt))
.build();

Set<String> functionsForThisRequest = new HashSet<>();

if (this.defaultOptions != null) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, AnthropicChatRequest.class);

Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
!IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, Anthropic3ChatOptions.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class);

Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(defaultEnabledFunctions);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
}

if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
List<Anthropic3ChatBedrockApi.Tool> tools = getFunctionTools(functionsForThisRequest);

request = AnthropicChatRequest.from(request).withTools(tools).build();
}

return request;
}

Expand Down Expand Up @@ -187,6 +218,80 @@ else if (mediaData instanceof String text) {
}
}

private List<Anthropic3ChatBedrockApi.Tool> getFunctionTools(Set<String> functionNames) {
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
var description = functionCallback.getDescription();
var name = functionCallback.getName();
String inputSchema = functionCallback.getInputTypeSchema();
return new Anthropic3ChatBedrockApi.Tool(name, description, ModelOptionsUtils.jsonToMap(inputSchema));
}).toList();
}

@Override
protected AnthropicChatRequest doCreateToolResponseRequest(AnthropicChatRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {

List<MediaContent> toolToUseList = responseMessage.content()
.stream()
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
.toList();

List<MediaContent> toolResults = new ArrayList<>();

for (MediaContent toolToUse : toolToUseList) {

var functionCallId = toolToUse.id();
var functionName = toolToUse.name();
var functionArguments = toolToUse.input();

if (!this.functionCallbackRegister.containsKey(functionName)) {
throw new IllegalStateException("No function callback found for function name: " + functionName);
}

String functionResponse = this.functionCallbackRegister.get(functionName)
.call(ModelOptionsUtils.toJsonString(functionArguments));

toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
}

// Add the function response to the conversation.
conversationHistory.add(new ChatCompletionMessage(toolResults, Role.USER));

// Recursively call chatCompletionWithTools until the model doesn't call a
// functions anymore.
return AnthropicChatRequest.from(previousRequest).withMessages(conversationHistory).build();
}

@Override
protected List<ChatCompletionMessage> doGetUserMessages(AnthropicChatRequest request) {
return request.messages();
}

@Override
protected ChatCompletionMessage doGetToolResponseMessage(AnthropicChatResponse response) {
return new ChatCompletionMessage(response.content(), Role.ASSISTANT);
}

@Override
protected AnthropicChatResponse doChatCompletion(AnthropicChatRequest request) {
return this.anthropicChatApi.chatCompletion(request);
}

@Override
protected Flux<AnthropicChatResponse> doChatCompletionStream(AnthropicChatRequest request) {
// https://docs.anthropic.com/en/docs/tool-use
throw new UnsupportedOperationException(
"Streaming (stream=true) is not yet supported. We plan to add streaming support in a future beta version.");
}

@Override
protected boolean isToolFunctionCall(AnthropicChatResponse response) {
if (response == null || CollectionUtils.isEmpty(response.content())) {
return false;
}
return response.content().stream().anyMatch(content -> content.type() == MediaContent.Type.TOOL_USE);
}

@Override
public ChatOptions getDefaultOptions() {
return Anthropic3ChatOptions.fromOptions(this.defaultOptions);
Expand Down
Loading