From 5253e65214785f25f84ad731e825515e6033fc3a Mon Sep 17 00:00:00 2001 From: wmz7year Date: Sun, 2 Jun 2024 14:54:16 +0800 Subject: [PATCH] Add Bedrock Anthropic3 tool support. --- .../anthropic3/Anthropic3ChatOptions.java | 78 +++++++- .../BedrockAnthropic3ChatModel.java | 109 +++++++++- .../api/Anthropic3ChatBedrockApi.java | 114 ++++++++++- .../BedrockAnthropic3ChatModelIT.java | 28 +++ .../tool/Anthropic3ChatBedrockApiToolIT.java | 160 +++++++++++++++ .../api/tool/MockWeatherService.java | 89 +++++++++ .../src/main/antora/modules/ROOT/nav.adoc | 1 + .../functions/anthropic-chat-functions.adoc | 2 +- .../bedrock-anthropic3-chat-functions.adoc | 188 ++++++++++++++++++ .../modules/ROOT/pages/api/functions.adoc | 3 +- 10 files changed, 757 insertions(+), 15 deletions(-) create mode 100644 models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/Anthropic3ChatBedrockApiToolIT.java create mode 100644 models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/MockWeatherService.java create mode 100644 spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java index b4995683a3b..e2286da2478 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java @@ -15,19 +15,28 @@ */ package org.springframework.ai.bedrock.anthropic3; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; + import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.boot.context.properties.NestedConfigurationProperty; +import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; /** * @author Ben Middleton * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) -public class Anthropic3ChatOptions implements ChatOptions { +public class Anthropic3ChatOptions implements ChatOptions, FunctionCallingOptions { // @formatter:off /** @@ -66,6 +75,32 @@ public class Anthropic3ChatOptions implements ChatOptions { * The version of the generative to use. The default value is bedrock-2023-05-31. */ private @JsonProperty("anthropic_version") String anthropicVersion; + + /** + * Tool Function Callbacks to register with the ChatModel. For Prompt + * Options the functionCallbacks are automatically enabled for the duration of the + * prompt execution. For Default Options the functionCallbacks are registered but + * disabled by default. Use the enableFunctions to set the functions from the registry + * to be used by the ChatModel chat completion requests. + */ + @NestedConfigurationProperty + @JsonIgnore + private List functionCallbacks = new ArrayList<>(); + + /** + * List of functions, identified by their names, to configure for function calling in + * the chat completion requests. Functions with those names must exist in the + * functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions + * are automatically enabled for the duration of the prompt execution. + * + * Note that function enabled with the default options are enabled for all chat + * completion requests. This could impact the token count and the billing. If the + * functions is set in a prompt options, then the enabled functions are only active + * for the duration of this prompt execution. + */ + @NestedConfigurationProperty + @JsonIgnore + private Set functions = new HashSet<>(); // @formatter:on public static Builder builder() { @@ -106,6 +141,23 @@ public Builder withAnthropicVersion(String anthropicVersion) { return this; } + public Builder withFunctionCallbacks(List functionCallbacks) { + this.options.functionCallbacks = functionCallbacks; + return this; + } + + public Builder withFunctions(Set functionNames) { + Assert.notNull(functionNames, "Function names must not be null"); + this.options.functions = functionNames; + return this; + } + + public Builder withFunction(String functionName) { + Assert.hasText(functionName, "Function name must not be empty"); + this.options.functions.add(functionName); + return this; + } + public Anthropic3ChatOptions build() { return this.options; } @@ -163,6 +215,28 @@ public void setAnthropicVersion(String anthropicVersion) { this.anthropicVersion = anthropicVersion; } + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public void setFunctionCallbacks(List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks = functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + @Override + public void setFunctions(Set functions) { + Assert.notNull(functions, "Function must not be null"); + this.functions = functions; + } + public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) { return builder().withTemperature(fromOptions.getTemperature()) .withMaxTokens(fromOptions.getMaxTokens()) @@ -170,6 +244,8 @@ public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOption .withTopP(fromOptions.getTopP()) .withStopSequences(fromOptions.getStopSequences()) .withAnthropicVersion(fromOptions.getAnthropicVersion()) + .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) + .withFunctions(fromOptions.getFunctions()) .build(); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java index e84ad252042..d7edaff112f 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.java @@ -17,7 +17,9 @@ import java.util.ArrayList; import java.util.Base64; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -30,6 +32,7 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent.Type; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -40,6 +43,8 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.util.CollectionUtils; /** @@ -48,9 +53,12 @@ * * @author Ben Middleton * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ -public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel { +public class BedrockAnthropic3ChatModel extends + AbstractFunctionCallSupport + implements ChatModel, StreamingChatModel { private final Anthropic3ChatBedrockApi anthropicChatApi; @@ -67,6 +75,13 @@ public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) { } public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) { + this(chatApi, options, null); + } + + public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options, + FunctionCallbackContext functionCallbackContext) { + super(functionCallbackContext); + this.anthropicChatApi = chatApi; this.defaultOptions = options; } @@ -76,7 +91,7 @@ public ChatResponse call(Prompt prompt) { AnthropicChatRequest request = createRequest(prompt); - AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); + AnthropicChatResponse response = this.callWithFunctionSupport(request); return new ChatResponse(List.of(new Generation(response.content().get(0).text()))); } @@ -117,8 +132,14 @@ AnthropicChatRequest createRequest(Prompt prompt) { .withSystem(toAnthropicSystemContext(prompt)) .build(); + Set functionsForThisRequest = new HashSet<>(); + if (this.defaultOptions != null) { request = ModelOptionsUtils.merge(request, this.defaultOptions, AnthropicChatRequest.class); + + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, + !IS_RUNTIME_CALL); + functionsForThisRequest.addAll(promptEnabledFunctions); } if (prompt.getOptions() != null) { @@ -126,6 +147,10 @@ AnthropicChatRequest createRequest(Prompt prompt) { Anthropic3ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class, Anthropic3ChatOptions.class); request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, AnthropicChatRequest.class); + + Set defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, + IS_RUNTIME_CALL); + functionsForThisRequest.addAll(defaultEnabledFunctions); } else { throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " @@ -133,6 +158,12 @@ AnthropicChatRequest createRequest(Prompt prompt) { } } + if (!CollectionUtils.isEmpty(functionsForThisRequest)) { + List tools = getFunctionTools(functionsForThisRequest); + + request = AnthropicChatRequest.from(request).withTools(tools).build(); + } + return request; } @@ -187,6 +218,80 @@ else if (mediaData instanceof String text) { } } + private List getFunctionTools(Set functionNames) { + return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> { + var description = functionCallback.getDescription(); + var name = functionCallback.getName(); + String inputSchema = functionCallback.getInputTypeSchema(); + return new Anthropic3ChatBedrockApi.Tool(name, description, ModelOptionsUtils.jsonToMap(inputSchema)); + }).toList(); + } + + @Override + protected AnthropicChatRequest doCreateToolResponseRequest(AnthropicChatRequest previousRequest, + ChatCompletionMessage responseMessage, List conversationHistory) { + + List toolToUseList = responseMessage.content() + .stream() + .filter(c -> c.type() == MediaContent.Type.TOOL_USE) + .toList(); + + List toolResults = new ArrayList<>(); + + for (MediaContent toolToUse : toolToUseList) { + + var functionCallId = toolToUse.id(); + var functionName = toolToUse.name(); + var functionArguments = toolToUse.input(); + + if (!this.functionCallbackRegister.containsKey(functionName)) { + throw new IllegalStateException("No function callback found for function name: " + functionName); + } + + String functionResponse = this.functionCallbackRegister.get(functionName) + .call(ModelOptionsUtils.toJsonString(functionArguments)); + + toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse)); + } + + // Add the function response to the conversation. + conversationHistory.add(new ChatCompletionMessage(toolResults, Role.USER)); + + // Recursively call chatCompletionWithTools until the model doesn't call a + // functions anymore. + return AnthropicChatRequest.from(previousRequest).withMessages(conversationHistory).build(); + } + + @Override + protected List doGetUserMessages(AnthropicChatRequest request) { + return request.messages(); + } + + @Override + protected ChatCompletionMessage doGetToolResponseMessage(AnthropicChatResponse response) { + return new ChatCompletionMessage(response.content(), Role.ASSISTANT); + } + + @Override + protected AnthropicChatResponse doChatCompletion(AnthropicChatRequest request) { + return this.anthropicChatApi.chatCompletion(request); + } + + @Override + protected Flux doChatCompletionStream(AnthropicChatRequest request) { + // https://docs.anthropic.com/en/docs/tool-use + throw new UnsupportedOperationException( + "Streaming (stream=true) is not yet supported. We plan to add streaming support in a future beta version."); + } + + @Override + protected boolean isToolFunctionCall(AnthropicChatResponse response) { + if (response == null || CollectionUtils.isEmpty(response.content())) { + return false; + } + return response.content().stream().anyMatch(content -> content.type() == MediaContent.Type.TOOL_USE); + } + @Override public ChatOptions getDefaultOptions() { return Anthropic3ChatOptions.fromOptions(this.defaultOptions); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java index 8b5b29ed1e8..87eed8767db 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java @@ -31,6 +31,7 @@ import java.time.Duration; import java.util.List; +import java.util.Map; /** * Based on Bedrock's stopSequences, - @JsonProperty("anthropic_version") String anthropicVersion) { + @JsonProperty("anthropic_version") String anthropicVersion, + @JsonProperty("tools") List tools, + @JsonProperty("tool_choice") ToolChoice toolChoice) { + + /** + * Specifices how the model should use the provided tools. The model can use a specific tool, any available tool, + * or decide by itself. + * + * @param type The type of tool choice. Possibile values are any (use any available tool),auto (the model decides), + * andtool (use the specified tool). + * @param name The name of the tool to use. Required if you specify tool in the type field. + */ + @JsonInclude(Include.NON_NULL) + public record ToolChoice( + @JsonProperty("type") String type, + @JsonProperty("name") String name) { + } public static Builder builder(List messages) { return new Builder(messages); } + public static Builder from(AnthropicChatRequest request) { + return new Builder(request); + } + public static class Builder { - private final List messages; + private List messages; private String system; - private Float temperature;// = 0.7f; - private Integer maxTokens;// = 500; - private Integer topK;// = 10; + private Float temperature; + private Integer maxTokens; + private Integer topK; private Float topP; private List stopSequences; private String anthropicVersion; + private List tools; + private ToolChoice toolChoice; private Builder(List messages) { this.messages = messages; } + private Builder(AnthropicChatRequest request) { + this.messages = request.messages(); + this.system = request.system(); + this.temperature = request.temperature(); + this.maxTokens = request.maxTokens(); + this.topK = request.topK(); + this.topP = request.topP(); + this.stopSequences = request.stopSequences(); + this.anthropicVersion = request.anthropicVersion(); + this.tools = request.tools(); + this.toolChoice = request.toolChoice(); + } + + public Builder withMessages(List messages) { + this.messages = messages; + return this; + } + public Builder withSystem(String system) { this.system = system; return this; @@ -205,6 +249,16 @@ public Builder withAnthropicVersion(String anthropicVersion) { return this; } + public Builder withTools(List tools) { + this.tools = tools; + return this; + } + + public Builder withToolChoice(ToolChoice toolChoice) { + this.toolChoice = toolChoice; + return this; + } + public AnthropicChatRequest build() { return new AnthropicChatRequest( messages, @@ -214,7 +268,9 @@ public AnthropicChatRequest build() { topK, topP, stopSequences, - anthropicVersion + anthropicVersion, + tools, + toolChoice ); } } @@ -232,7 +288,18 @@ public record MediaContent( // @formatter:off @JsonProperty("type") Type type, @JsonProperty("source") Source source, @JsonProperty("text") String text, - @JsonProperty("index") Integer index // applicable only for streaming responses. + + // applicable only for streaming responses. + @JsonProperty("index") Integer index, + + // tool_use response only + @JsonProperty("id") String id, + @JsonProperty("name") String name, + @JsonProperty("input") Map input, + + // tool_result response only + @JsonProperty("tool_use_id") String toolUseId, + @JsonProperty("content") String content ) { // @formatter:on @@ -241,11 +308,16 @@ public MediaContent(String mediaType, String data) { } public MediaContent(Source source) { - this(Type.IMAGE, source, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null); } public MediaContent(String text) { - this(Type.TEXT, null, text, null); + this(Type.TEXT, null, text, null, null, null, null, null, null); + } + + // Tool result + public MediaContent(Type type, String toolUseId, String content) { + this(type, null, null, null, null, null, null, toolUseId, content); } /** @@ -262,7 +334,17 @@ public enum Type { * Image message. */ @JsonProperty("image") - IMAGE + IMAGE, + /** + * Tool use message. + */ + @JsonProperty("tool_use") + TOOL_USE, + /** + * Send tool result back to LLM. + */ + @JsonProperty("tool_result") + TOOL_RESULT, } @@ -288,6 +370,18 @@ public Source(String mediaType, String data) { } } + /** + * Definitions of tools that the model may use. + * + * @param name The name of the tool. + * @param description The description of the tool. + * @param inputSchema The schema for the tool. + */ + @JsonInclude(Include.NON_NULL) + public record Tool(@JsonProperty("name") String name, @JsonProperty("description") String description, + @JsonProperty("input_schema") Map inputSchema) { + } + /** * Message comprising the conversation. * diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java index 7e31a8f0174..550f7140568 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.time.Duration; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -32,6 +33,7 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; +import org.springframework.ai.bedrock.anthropic3.api.tool.MockWeatherService; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.messages.AssistantMessage; @@ -44,6 +46,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -203,6 +206,31 @@ void beanStreamOutputConverterRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = Anthropic3ChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.") + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + Generation generation = response.getResult(); + assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30"); + assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10"); + assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15"); + } + @Test void multiModalityTest() throws IOException { diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/Anthropic3ChatBedrockApiToolIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/Anthropic3ChatBedrockApiToolIT.java new file mode 100644 index 00000000000..3326bd8dfd1 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/Anthropic3ChatBedrockApiToolIT.java @@ -0,0 +1,160 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.anthropic3.api.tool; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatModel; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent.Type; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.Tool; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.util.CollectionUtils; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.DEFAULT_ANTHROPIC_VERSION; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +@SuppressWarnings({ "rawtypes", "unchecked" }) +public class Anthropic3ChatBedrockApiToolIT { + + private static final Logger logger = LoggerFactory.getLogger(Anthropic3ChatBedrockApiToolIT.class); + + private Anthropic3ChatBedrockApi anthropicChatApi = new Anthropic3ChatBedrockApi( + AnthropicChatModel.CLAUDE_V3_SONNET.id(), EnvironmentVariableCredentialsProvider.create(), + Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2)); + + public static final ConcurrentHashMap FUNCTIONS = new ConcurrentHashMap<>(); + + static { + FUNCTIONS.put("getCurrentWeather", new MockWeatherService()); + } + + List tools = List.of(new Tool("getCurrentWeather", + "Get the weather in location. Return temperature in 30°F or 30°C format.", ModelOptionsUtils.jsonToMap(""" + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "unit"] + } + """))); + + @Test + void toolCalls() { + + List messageConversation = new ArrayList<>(); + + ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage(List.of(new MediaContent( + "What's the weather like in San Francisco, Tokyo, and Paris? Show the temperature in Celsius.")), + Role.USER); + + messageConversation.add(chatCompletionMessage); + + AnthropicChatResponse chatCompletion = doCall(messageConversation); + + var responseText = chatCompletion.content().get(0).text(); + logger.info("FINAL RESPONSE: " + responseText); + + assertThat(responseText).contains("15"); + assertThat(responseText).contains("10"); + assertThat(responseText).contains("30"); + } + + private AnthropicChatResponse doCall(List messageConversation) { + + AnthropicChatRequest chatCompletionRequest = AnthropicChatRequest.builder(messageConversation) + .withMaxTokens(1500) + .withTemperature(0.8f) + .withTools(tools) + .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) + .build(); + + AnthropicChatResponse response = anthropicChatApi.chatCompletion(chatCompletionRequest); + + List toolToUseList = response.content() + .stream() + .filter(c -> c.type() == MediaContent.Type.TOOL_USE) + .toList(); + + if (CollectionUtils.isEmpty(toolToUseList)) { + return response; + } + // Add use tool message to the conversation history + messageConversation.add(new ChatCompletionMessage(response.content(), Role.ASSISTANT)); + + List toolResults = new ArrayList<>(); + + for (MediaContent toolToUse : toolToUseList) { + + var id = toolToUse.id(); + var name = toolToUse.name(); + var input = toolToUse.input(); + + logger.info("FunctionCalls from the LLM: " + name); + + MockWeatherService.Request request = ModelOptionsUtils.mapToClass(input, MockWeatherService.Request.class); + + logger.info("Resolved function request param: " + request); + + Object functionCallResponseData = FUNCTIONS.get(name).apply(request); + + String content = ModelOptionsUtils.toJsonString(functionCallResponseData); + + logger.info("Function response : " + content); + + toolResults.add(new MediaContent(Type.TOOL_RESULT, id, content)); + } + + // Add function response message to the conversation history + messageConversation.add(new ChatCompletionMessage(toolResults, Role.USER)); + + return doCall(messageConversation); + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/MockWeatherService.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/MockWeatherService.java new file mode 100644 index 00000000000..b90cfa09507 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/tool/MockWeatherService.java @@ -0,0 +1,89 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.anthropic3.api.tool; + +import java.util.function.Function; + +import com.fasterxml.jackson.annotation.JsonClassDescription; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; + +/** + * @author Christian Tzolov + */ +public class MockWeatherService implements Function { + + /** + * Weather Function request. + */ + @JsonInclude(Include.NON_NULL) + @JsonClassDescription("Weather API request") + public record Request(@JsonProperty(required = true, + value = "location") @JsonPropertyDescription("The city and state e.g. San Francisco, CA") String location, + @JsonProperty(required = true, value = "unit") @JsonPropertyDescription("Temperature unit") Unit unit) { + } + + /** + * Temperature units. + */ + public enum Unit { + + /** + * Celsius. + */ + C("metric"), + /** + * Fahrenheit. + */ + F("imperial"); + + /** + * Human readable unit name. + */ + public final String unitName; + + private Unit(String text) { + this.unitName = text; + } + + } + + /** + * Weather Function response. + */ + public record Response(double temp, Unit unit) { + } + + @Override + public Response apply(Request request) { + + double temperature = 0; + if (request.location().contains("Paris")) { + temperature = 15; + } + else if (request.location().contains("Tokyo")) { + temperature = 10; + } + else if (request.location().contains("San Francisco")) { + temperature = 30; + } + + return new Response(temperature, Unit.C); + } + +} \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 120a8e5c8b2..9475388b43f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -11,6 +11,7 @@ **** xref:api/chat/functions/azure-open-ai-chat-functions.adoc[Function Calling] *** xref:api/bedrock-chat.adoc[Amazon Bedrock] **** xref:api/chat/bedrock/bedrock-anthropic3.adoc[Anthropic3] +***** xref:api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc[Function Calling] **** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic2] **** xref:api/chat/bedrock/bedrock-llama.adoc[Llama] **** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc index fea58dc1d07..1e13b223fdb 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/anthropic-chat-functions.adoc @@ -6,7 +6,7 @@ The `claude-3-opus`, `claude-3-sonnet` and `claude-3-haiku` link:https://docs.an The Anthropic API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. -NOTE: As of April 4th, 2024, streaming is not yet supported for function calling and Tool use is not yet available on third-party platforms like Vertex AI or AWS Bedrock, but is coming soon. +NOTE: As of April 4th, 2024, streaming is not yet supported for function calling and Tool use is not yet available on third-party platforms like Vertex AI, but is coming soon. Spring AI provides flexible and user-friendly ways to register and call custom functions. In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc new file mode 100644 index 00000000000..8e34ddef5e9 --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc @@ -0,0 +1,188 @@ += Bedrock Anthropic 3 Function Calling + +You can register custom Java functions with the `BedrockAnthropic3ChatModel` and have the Bedrock Anthropic 3 models intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This allows you to connect the LLM capabilities with external tools and APIs. + +The Bedrock Anthropic 3 API does not call the function directly; instead, the model generates JSON that you can use to call the function in your code and return the result back to the model to complete the conversation. + +Spring AI provides flexible and user-friendly ways to register and call custom functions. +In general, the custom functions need to provide a function `name`, `description`, and the function call `signature` (as JSON schema) to let the model know what arguments the function expects. +The `description` helps the model to understand when to call the function. + +As a developer, you need to implement a function that takes the function call arguments sent from the AI model, and respond with the result back to the model. +Your function can in turn invoke other 3rd party services to provide the results. + +Spring AI makes this as easy as defining a `@Bean` definition that returns a `java.util.Function` and supplying the bean name as an option when invoking the `ChatModel`. + +Under the hood, Spring wraps your POJO (the function) with the appropriate adapter code that enables interaction with the AI Model, saving you from writing tedious boilerplate code. +The basis of the underlying infrastructure is the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java[FunctionCallback.java] interface and the companion link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackWrapper.java[FunctionCallbackWrapper.java] utility class to simplify the implementation and registration of Java callback functions. + +== How it works + +Suppose we want the AI model to respond with information that it does not have, for example the current temperature at a given location. + +We can provide the AI model with metadata about our own functions that it can use to retrieve that information as it processes your prompt. + +For example, if during the processing of a prompt, the AI Model determines that it needs additional information about the temperature in a given location, it will start a server side generated request/response interaction. The AI Model invokes a client side function. +The AI Model provides method invocation details as JSON and it is the responsibility of the client to execute that function and return the response. + +Spring AI greatly simplifies the code you need to write to support function invocation. +It brokers the function invocation conversation for you. +You can simply provide your function definition as a `@Bean` and then provide the bean name of the function in your prompt options. +You can also reference multiple function bean names in your prompt. + +== Quick Start + +Let's create a chatbot that answer questions by calling our own function. +To support the response of the chatbot, we will register our own function that takes a location and returns the current weather in that location. + +When the response to the prompt to the model needs to answer a question such as `"What’s the weather like in Boston?"` the AI model will invoke the client providing the location value as an argument to be passed to the function. This RPC-like data is passed as JSON. + +Our function can some SaaS based weather service API and returns the weather response back to the model to complete the conversation. +In this example we will use a simple implementation named `MockWeatherService` that hard codes the temperature for various locations. + +The following `MockWeatherService.java` represents the weather service API: + +[source,java] +---- +public class MockWeatherService implements Function { + + public enum Unit { C, F } + public record Request(String location, Unit unit) {} + public record Response(double temp, Unit unit) {} + + public Response apply(Request request) { + return new Response(30.0, Unit.C); + } +} +---- + +=== Registering Functions as Beans + +With the link:../bedrock/bedrock-anthropic3.html#_auto_configuration[BedrockAnthropic3ChatModel Auto-Configuration] you have multiple ways to register custom functions as beans in the Spring context. + +We start with describing the most POJO friendly options. + +==== Plain Java Functions + +In this approach you define `@Beans` in your application context as you would any other Spring managed object. + +Internally, Spring AI `ChatModel` will create an instance of a `FunctionCallbackWrapper` wrapper that adds the logic for it being invoked via the AI model. +The name of the `@Bean` is passed as a `ChatOption`. + + +[source,java] +---- +@Configuration +static class Config { + + @Bean + @Description("Get the weather in location") // function description + public Function weatherFunction1() { + return new MockWeatherService(); + } + ... +} +---- + +The `@Description` annotation is optional and provides a function description (2) that helps the model understand when to call the function. +It is an important property to set to help the AI model determine what client side function to invoke. + +Another option to provide the description of the function is to use the `@JsonClassDescription` annotation on the `MockWeatherService.Request` to provide the function description: + +[source,java] +---- + +@Configuration +static class Config { + + @Bean + public Function currentWeather3() { // (1) bean name as function name. + return new MockWeatherService(); + } + ... +} + +@JsonClassDescription("Get the weather in location") // (2) function description +public record Request(String location, Unit unit) {} +---- + +It is a best practice to annotate the request object with information such that the generated JSON schema of that function is as descriptive as possible to help the AI model pick the correct function to invoke. + +The link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithFunctionBeanIT.java.java[FunctionCallWithFunctionBeanIT.java] demonstrates this approach. + + +==== FunctionCallback Wrapper + +Another way to register a function is to create a `FunctionCallbackWrapper` wrapper like this: + +[source,java] +---- +@Configuration +static class Config { + + @Bean + public FunctionCallback weatherFunctionInfo() { + + return new FunctionCallbackWrapper<>("CurrentWeather", // (1) function name + "Get the weather in location", // (2) function description + (response) -> "" + response.temp() + response.unit(), // (3) Response Converter + new MockWeatherService()); // function code + } + ... +} +---- + +It wraps the 3rd party `MockWeatherService` function and registers it as a `CurrentWeather` function with the `BedrockAnthropic3ChatModel`. +It also provides a description (2) and an optional response converter (3) to convert the response into a text as expected by the model. + +NOTE: By default, the response converter does a JSON serialization of the Response object. + +NOTE: The `FunctionCallbackWrapper` internally resolves the function call signature based on the `MockWeatherService.Request` class. + +=== Specifying functions in Chat Options + +To let the model know and call your `CurrentWeather` function you need to enable it in your prompt requests: + +[source,java] +---- +BedrockAnthropic3ChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), + Anthropic3ChatOptions.builder().withFunction("CurrentWeather").build())); // (1) Enable the function + +logger.info("Response: {}", response); +---- + +// NOTE: You can can have multiple functions registered in your `ChatModel` but only those enabled in the prompt request will be considered for the function calling. + +Above user question will trigger 3 calls to `CurrentWeather` function (one for each city) and produce the final response. + +=== Register/Call Functions with Prompt Options + +In addition to the auto-configuration you can register callback functions, dynamically, with your Prompt requests: + +[source,java] +---- +BedrockAnthropic3ChatModel chatModel = ... + +UserMessage userMessage = new UserMessage("What's the weather like in Paris?"); + +var promptOptions = Anthropic3ChatOptions.builder() + .withFunctionCallbacks(List.of(new FunctionCallbackWrapper<>( + "CurrentWeather", // name + "Get the weather in location", // function description + new MockWeatherService()))) // function code + .build(); + +ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); +---- + +NOTE: The in-prompt registered functions are enabled by default for the duration of this request. + +This approach allows to dynamically chose different functions to be called based on the user input. + +The https://github.com/spring-projects/spring-ai/blob/main/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java[FunctionCallWithPromptFunctionIT.java] integration test provides a complete example of how to register a function with the `BedrockAnthropic3ChatModel` and use it in a prompt request. + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc index b351d4b5941..c704e832202 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/functions.adoc @@ -13,4 +13,5 @@ Spring AI currently supports Function invocation for the following AI Models * Mistral AI: Refer to the xref:api/chat/functions/mistralai-chat-functions.adoc[Mistral AI function invocation docs]. * Anthropic Claude: Refer to the xref:api/chat/functions/anthropic-chat-functions.adoc[Anthropic Claude function invocation docs]. * MiniMax : Refer to the xref:api/chat/functions/minimax-chat-functions.adoc[MiniMax function invocation docs]. -* ZhiPu AI : Refer to the xref:api/chat/functions/zhipuai-chat-functions.adoc[ZhiPu AI function invocation docs]. \ No newline at end of file +* ZhiPu AI : Refer to the xref:api/chat/functions/zhipuai-chat-functions.adoc[ZhiPu AI function invocation docs]. +* Amazon Bedrock Anthropic 3 : Refer to the xref:api/chat/functions/bedrock/bedrock-anthropic3-chat-functions.adoc[Amazon Bedrock Anthropic3 function invocation docs]. \ No newline at end of file