diff --git a/README.md b/README.md index fe4aa92acdb..33fff49f2a2 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ You can find more details in the [Reference Documentation](https://docs.spring.i Spring AI supports many AI models. For an overview see here. Specific models currently supported are * OpenAI * Azure OpenAI -* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2) +* Amazon Bedrock (Anthropic, Llama, Cohere, Titan, Jurassic2, Mistral) * HuggingFace * Google VertexAI (PaLM2, Gemini) * Mistral AI diff --git a/models/spring-ai-bedrock/README.md b/models/spring-ai-bedrock/README.md index 19e48518a60..782af7a8535 100644 --- a/models/spring-ai-bedrock/README.md +++ b/models/spring-ai-bedrock/README.md @@ -8,4 +8,5 @@ - [Titan Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-titan.html) - [Titan Embedding Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/embeddings/bedrock-titan-embedding.html) - [Jurassic2 Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-jurassic2.html) +- [Mistral Chat Documentation](https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/bedrock/bedrock-mistral.html) diff --git a/models/spring-ai-bedrock/pom.xml b/models/spring-ai-bedrock/pom.xml index e3b79d30bda..ae7913822a0 100644 --- a/models/spring-ai-bedrock/pom.xml +++ b/models/spring-ai-bedrock/pom.xml @@ -29,6 +29,12 @@ ${project.parent.version} + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + org.springframework spring-web diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java new file mode 100644 index 00000000000..660837960dd --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java @@ -0,0 +1,146 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.mistral; + +import java.util.List; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.MessageToPromptConverter; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatRequest; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatResponse; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.util.Assert; + +import reactor.core.publisher.Flux; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockMistralChatModel implements ChatModel, StreamingChatModel { + + private final MistralChatBedrockApi chatApi; + + private final BedrockMistralChatOptions defaultOptions; + + /** + * The retry template used to retry the Bedrock API calls. + */ + private final RetryTemplate retryTemplate; + + public BedrockMistralChatModel(MistralChatBedrockApi chatApi) { + this(chatApi, BedrockMistralChatOptions.builder().build()); + } + + public BedrockMistralChatModel(MistralChatBedrockApi chatApi, BedrockMistralChatOptions options) { + this(chatApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public BedrockMistralChatModel(MistralChatBedrockApi chatApi, BedrockMistralChatOptions options, + RetryTemplate retryTemplate) { + Assert.notNull(chatApi, "MistralChatBedrockApi must not be null"); + Assert.notNull(options, "BedrockMistralChatOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); + + this.chatApi = chatApi; + this.defaultOptions = options; + this.retryTemplate = retryTemplate; + } + + @Override + public ChatResponse call(Prompt prompt) { + + MistralChatRequest request = createRequest(prompt); + + return this.retryTemplate.execute(ctx -> { + MistralChatResponse response = this.chatApi.chatCompletion(request); + + List generations = response.outputs().stream().map(g -> { + return new Generation(g.text()); + }).toList(); + + return new ChatResponse(generations); + }); + } + + public Flux stream(Prompt prompt) { + + MistralChatRequest request = createRequest(prompt); + + return this.retryTemplate.execute(ctx -> { + return this.chatApi.chatCompletionStream(request).map(g -> { + List generations = g.outputs().stream().map(output -> { + Generation generation = new Generation(output.text()); + + if (g.amazonBedrockInvocationMetrics() != null) { + Usage usage = BedrockUsage.from(g.amazonBedrockInvocationMetrics()); + generation.withGenerationMetadata(ChatGenerationMetadata.from(output.stopReason(), usage)); + } + + return generation; + }).toList(); + + return new ChatResponse(generations); + }); + }); + } + + /** + * Test access. 
+     */
+    MistralChatRequest createRequest(Prompt prompt) {
+        final String promptValue = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());
+
+        var request = MistralChatRequest.builder(promptValue)
+            .withTemperature(this.defaultOptions.getTemperature())
+            .withTopP(this.defaultOptions.getTopP())
+            .withTopK(this.defaultOptions.getTopK())
+            .withMaxTokens(this.defaultOptions.getMaxTokens())
+            .withStopSequences(this.defaultOptions.getStopSequences())
+            .build();
+
+        if (prompt.getOptions() != null) {
+            if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
+                BedrockMistralChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
+                        ChatOptions.class, BedrockMistralChatOptions.class);
+                request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, MistralChatRequest.class);
+            }
+            else {
+                throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+                        + prompt.getOptions().getClass().getSimpleName());
+            }
+        }
+
+        return request;
+    }
+
+    @Override
+    public ChatOptions getDefaultOptions() {
+        return defaultOptions;
+    }
+
+}
diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java
new file mode 100644
index 00000000000..e1447d5a1ad
--- /dev/null
+++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java
@@ -0,0 +1,155 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.bedrock.mistral;
+
+import java.util.List;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonInclude.Include;
+
+import org.springframework.ai.chat.prompt.ChatOptions;
+
+/**
+ * @author Wei Jiang
+ * @since 1.0.0
+ */
+@JsonInclude(Include.NON_NULL)
+public class BedrockMistralChatOptions implements ChatOptions {
+
+    /**
+     * The temperature value controls the randomness of the generated text. Use a lower
+     * value to decrease randomness in the response.
+     */
+    private @JsonProperty("temperature") Float temperature;
+
+    /**
+     * (optional) The maximum cumulative probability of tokens to consider when sampling.
+     * The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers
+     * the smallest set of tokens whose probability sum is at least topP.
+     */
+    private @JsonProperty("top_p") Float topP;
+
+    /**
+     * (optional) Specify the number of token choices the generative uses to generate the
+     * next token.
+     */
+    private @JsonProperty("top_k") Integer topK;
+
+    /**
+     * (optional) Specify the maximum number of tokens to use in the generated response.
+     */
+    private @JsonProperty("max_tokens") Integer maxTokens;
+
+    /**
+     * (optional) Configure up to four sequences that the generative recognizes.
After a + * stop sequence, the generative stops generating further tokens. The returned text + * doesn't contain the stop sequence. + */ + private @JsonProperty("stop") List stopSequences; + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final BedrockMistralChatOptions options = new BedrockMistralChatOptions(); + + public Builder withTemperature(Float temperature) { + this.options.setTemperature(temperature); + return this; + } + + public Builder withTopP(Float topP) { + this.options.setTopP(topP); + return this; + } + + public Builder withTopK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.options.setStopSequences(stopSequences); + return this; + } + + public BedrockMistralChatOptions build() { + return this.options; + } + + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public List getStopSequences() { + return stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + public static BedrockMistralChatOptions fromOptions(BedrockMistralChatOptions fromOptions) { + return builder().withTemperature(fromOptions.getTemperature()) + .withTopP(fromOptions.getTopP()) + .withTopK(fromOptions.getTopK()) + .withMaxTokens(fromOptions.getMaxTokens()) + .withStopSequences(fromOptions.getStopSequences()) + .build(); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApi.java new file mode 100644 index 00000000000..da6c861fcba --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApi.java @@ -0,0 +1,272 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.mistral.api; + +import java.time.Duration; +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.databind.ObjectMapper; + +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.api.AbstractBedrockApi; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatRequest; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatResponse; +import org.springframework.ai.model.ModelDescription; +import org.springframework.util.Assert; + +/** + * Java client for the Bedrock Mistral chat model. + * https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html + * + * @author Wei Jiang + * @since 1.0.0 + */ +// @formatter:off +public class MistralChatBedrockApi extends AbstractBedrockApi { + + /** + * Create a new MistralChatBedrockApi instance using the default credentials provider chain, the default object + * mapper, default temperature and topP values. + * + * @param modelId The model id to use. See the {@link MistralChatModel} for the supported models. + * @param region The AWS region to use. + */ + public MistralChatBedrockApi(String modelId, String region) { + super(modelId, region); + } + + /** + * Create a new MistralChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link MistralChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + */ + public MistralChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, + ObjectMapper objectMapper) { + super(modelId, credentialsProvider, region, objectMapper); + } + + /** + * Create a new MistralChatBedrockApi instance using the default credentials provider chain, the default object + * mapper, default temperature and topP values. + * + * @param modelId The model id to use. See the {@link MistralChatModel} for the supported models. + * @param region The AWS region to use. + * @param timeout The timeout to use. + */ + public MistralChatBedrockApi(String modelId, String region, Duration timeout) { + super(modelId, region, timeout); + } + + /** + * Create a new MistralChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. See the {@link MistralChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public MistralChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, String region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + + /** + * Create a new MistralChatBedrockApi instance using the provided credentials provider, region and object mapper. + * + * @param modelId The model id to use. 
See the {@link MistralChatModel} for the supported models. + * @param credentialsProvider The credentials provider to connect to AWS. + * @param region The AWS region to use. + * @param objectMapper The object mapper to use for JSON serialization and deserialization. + * @param timeout The timeout to use. + */ + public MistralChatBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region, + ObjectMapper objectMapper, Duration timeout) { + super(modelId, credentialsProvider, region, objectMapper, timeout); + } + + /** + * MistralChatRequest encapsulates the request parameters for the Mistral model. + * + * @param prompt The input prompt to generate the response from. + * @param temperature (optional) Use a lower value to decrease randomness in the response. + * @param topP (optional) Use a lower value to ignore less probable options. Set to 0 or 1.0 to disable. + * @param topK (optional) Specify the number of token choices the model uses to generate the next token. + * @param maxTokens (optional) Specify the maximum number of tokens to use in the generated response. + * @param stopSequences (optional) Configure up to four sequences that the model recognizes. After a stop sequence, + * the model stops generating further tokens. The returned text doesn't contain the stop sequence. + */ + @JsonInclude(Include.NON_NULL) + public record MistralChatRequest( + @JsonProperty("prompt") String prompt, + @JsonProperty("temperature") Float temperature, + @JsonProperty("top_p") Float topP, + @JsonProperty("top_k") Integer topK, + @JsonProperty("max_tokens") Integer maxTokens, + @JsonProperty("stop") List stopSequences) { + + /** + * Get MistralChatRequest builder. + * @param prompt compulsory request prompt parameter. + * @return MistralChatRequest builder. + */ + public static Builder builder(String prompt) { + return new Builder(prompt); + } + + /** + * Builder for the MistralChatRequest. + */ + public static class Builder { + private final String prompt; + private Float temperature; + private Float topP; + private Integer topK; + private Integer maxTokens; + private List stopSequences; + + public Builder(String prompt) { + this.prompt = prompt; + } + + public Builder withTemperature(Float temperature) { + this.temperature = temperature; + return this; + } + + public Builder withTopP(Float topP) { + this.topP = topP; + return this; + } + + public Builder withTopK(Integer topK) { + this.topK = topK; + return this; + } + + public Builder withMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder withStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + public MistralChatRequest build() { + return new MistralChatRequest( + prompt, + temperature, + topP, + topK, + maxTokens, + stopSequences + ); + } + } + } + + /** + * MistralChatResponse encapsulates the response parameters for the Mistral model. + * + * @param A list of outputs from the model. Each output has the following fields. + */ + @JsonInclude(Include.NON_NULL) + public record MistralChatResponse( + @JsonProperty("outputs") List outputs, + @JsonProperty("amazon-bedrock-invocationMetrics") AmazonBedrockInvocationMetrics amazonBedrockInvocationMetrics) { + + /** + * Generated result along with the likelihoods for tokens requested. + * + * @param text The text that the model generated. + * @param stopReason The reason why the response stopped generating text. 
+ */ + public record Generation( + @JsonProperty("text") String text, + @JsonProperty("stop_reason") String stopReason) { + } + + } + + /** + * Mistral models version. + */ + public enum MistralChatModel implements ModelDescription { + + /** + * mistral.mistral-7b-instruct-v0:2 + */ + MISTRAL_7B_INSTRUCT("mistral.mistral-7b-instruct-v0:2"), + + /** + * mistral.mixtral-8x7b-instruct-v0:1 + */ + MISTRAL_8X7B_INSTRUCT("mistral.mixtral-8x7b-instruct-v0:1"), + + /** + * mistral.mistral-large-2402-v1:0 + */ + MISTRAL_LARGE("mistral.mistral-large-2402-v1:0"), + + /** + * mistral.mistral-small-2402-v1:0 + */ + MISTRAL_SMALL("mistral.mistral-small-2402-v1:0"); + + private final String id; + + /** + * @return The model id. + */ + public String id() { + return id; + } + + MistralChatModel(String value) { + this.id = value; + } + + @Override + public String getModelName() { + return this.id; + } + } + + @Override + public MistralChatResponse chatCompletion(MistralChatRequest mistralRequest) { + Assert.notNull(mistralRequest, "'MistralChatRequest' must not be null"); + return this.internalInvocation(mistralRequest, MistralChatResponse.class); + } + + @Override + public Flux chatCompletionStream(MistralChatRequest mistralRequest) { + Assert.notNull(mistralRequest, "'MistralChatRequest' must not be null"); + return this.internalInvocationStream(mistralRequest, MistralChatResponse.class); + } + +} +//@formatter:on \ No newline at end of file diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatCreateRequestTests.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatCreateRequestTests.java new file mode 100644 index 00000000000..356223add3c --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatCreateRequestTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.mistral; + +import java.time.Duration; +import java.util.List; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatModel; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatRequest; +import org.springframework.ai.chat.prompt.Prompt; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +public class BedrockMistralChatCreateRequestTests { + + private MistralChatBedrockApi chatApi = new MistralChatBedrockApi(MistralChatModel.MISTRAL_8X7B_INSTRUCT.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + + @Test + public void createRequestWithChatOptions() { + + var client = new BedrockMistralChatModel(chatApi, + BedrockMistralChatOptions.builder() + .withTemperature(66.6f) + .withTopK(66) + .withTopP(0.66f) + .withMaxTokens(678) + .withStopSequences(List.of("stop1", "stop2")) + .build()); + + MistralChatRequest request = client.createRequest(new Prompt("Test message content")); + + assertThat(request.prompt()).isNotEmpty(); + + assertThat(request.temperature()).isEqualTo(66.6f); + assertThat(request.topK()).isEqualTo(66); + assertThat(request.topP()).isEqualTo(0.66f); + assertThat(request.maxTokens()).isEqualTo(678); + assertThat(request.stopSequences()).containsExactly("stop1", "stop2"); + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java new file mode 100644 index 00000000000..1c172aa2c8f --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java @@ -0,0 +1,219 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.mistral; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatModel; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockMistralChatModelIT { + + @Autowired + private BedrockMistralChatModel chatModel; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void multipleStreamAttempts() { + + Flux joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?"))); + Flux joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?"))); + + String joke1 = joke1Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + String joke2 = joke2Stream.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(joke1).isNotBlank(); + assertThat(joke2).isNotBlank(); + } + + @Test + void roleTest() { + String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; + String name = "Bob"; + String voice = "pirate"; + UserMessage userMessage = new UserMessage(request); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = chatModel.call(prompt); + 
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors.", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getContent()); + assertThat(list).hasSize(5); + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Remove Markdown code blocks from the output. + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + Remove Markdown code blocks from the output. 
+ """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = chatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public MistralChatBedrockApi mistralApi() { + return new MistralChatBedrockApi(MistralChatModel.MISTRAL_8X7B_INSTRUCT.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockMistralChatModel mistralChatModel(MistralChatBedrockApi cohereApi) { + return new BedrockMistralChatModel(cohereApi); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApiIT.java new file mode 100644 index 00000000000..3389a9e520b --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApiIT.java @@ -0,0 +1,82 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.bedrock.mistral.api; + +import java.time.Duration; +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatRequest; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatResponse; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatModel; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import reactor.core.publisher.Flux; +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class MistralChatBedrockApiIT { + + private MistralChatBedrockApi mistralChatApi = new MistralChatBedrockApi( + MistralChatModel.MISTRAL_8X7B_INSTRUCT.id(), EnvironmentVariableCredentialsProvider.create(), + Region.US_EAST_1.id(), new ObjectMapper(), Duration.ofMinutes(2)); + + @Test + public void chatCompletion() { + + MistralChatRequest request = MistralChatRequest.builder("Hello, Who are you?") + .withTemperature(0.9f) + .withTopP(0.9f) + .build(); + + MistralChatResponse response = mistralChatApi.chatCompletion(request); + + assertThat(response).isNotNull(); + assertThat(response.outputs()).isNotEmpty(); + assertThat(response.outputs().get(0)).isNotNull(); + assertThat(response.outputs().get(0).text()).isNotNull(); + assertThat(response.outputs().get(0).stopReason()).isNotNull(); + } + + @Test + public void chatCompletionStream() { + + MistralChatRequest request = MistralChatRequest.builder("Hello, Who are you?") + .withTemperature(0.9f) + .withTopP(0.9f) + .build(); + Flux responseStream = mistralChatApi.chatCompletionStream(request); + List responses = responseStream.collectList().block(); + + assertThat(responses).isNotNull(); + assertThat(responses).hasSizeGreaterThan(10); + assertThat(responses.get(0).outputs()).isNotEmpty(); + + MistralChatResponse lastResponse = responses.get(responses.size() - 1); + assertThat(lastResponse.amazonBedrockInvocationMetrics()).isNotNull(); + } + +} diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-mistral-chat-low-level-api.png b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-mistral-chat-low-level-api.png new file mode 100644 index 00000000000..e1276c7b22f Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/bedrock/bedrock-mistral-chat-low-level-api.png differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 120a8e5c8b2..380b39c3852 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -16,6 +16,7 @@ **** xref:api/chat/bedrock/bedrock-cohere.adoc[Cohere] **** xref:api/chat/bedrock/bedrock-titan.adoc[Titan] **** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2] +**** xref:api/chat/bedrock/bedrock-mistral.adoc[Mistral] *** xref:api/chat/huggingface.adoc[HuggingFace] *** xref:api/chat/google-vertexai.adoc[Google VertexAI] **** xref:api/chat/vertexai-palm2-chat.adoc[VertexAI PaLM2 ] diff --git 
a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc
index f8b2b2062a1..dcad2f6e5e9 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/bedrock.adoc
@@ -108,3 +108,4 @@ For more information, refer to the documentation below for each supported model.
 * xref:api/chat/bedrock/bedrock-titan.adoc[Spring AI Bedrock Titan Chat]: `spring.ai.bedrock.titan.chat.enabled=true`
 * xref:api/embeddings/bedrock-titan-embedding.adoc[Spring AI Bedrock Titan Embeddings]: `spring.ai.bedrock.titan.embedding.enabled=true`
 * xref:api/chat/bedrock/bedrock-jurassic2.adoc[Spring AI Bedrock Ai21 Jurassic2 Chat]: `spring.ai.bedrock.jurassic2.chat.enabled=true`
+* xref:api/chat/bedrock/bedrock-mistral.adoc[Spring AI Bedrock Mistral Chat]: `spring.ai.bedrock.mistral.chat.enabled=true`
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-mistral.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-mistral.adoc
new file mode 100644
index 00000000000..0d329e8fa7a
--- /dev/null
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/bedrock/bedrock-mistral.adoc
@@ -0,0 +1,251 @@
+= Mistral Chat
+
+Provides the Bedrock Mistral chat model. Integrate generative AI capabilities into essential apps and workflows that improve business outcomes.
+
+The https://aws.amazon.com/bedrock/mistral/[AWS Bedrock Mistral Model Page] and https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html[Amazon Bedrock User Guide] contain detailed information on how to use the AWS hosted model.
+
+== Prerequisites
+
+Refer to the xref:api/bedrock.adoc[Spring AI documentation on Amazon Bedrock] for setting up API access.
+
+=== Add Repositories and BOM
+
+Spring AI artifacts are published in Spring Milestone and Snapshot repositories. Refer to the xref:getting-started.adoc#repositories[Repositories] section to add these repositories to your build system.
+
+To help with dependency management, Spring AI provides a BOM (bill of materials) to ensure that a consistent version of Spring AI is used throughout the entire project. Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build system.
+
+
+== Auto-configuration
+
+Add the `spring-ai-bedrock-ai-spring-boot-starter` dependency to your project's Maven `pom.xml` file:
+
+[source,xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-bedrock-ai-spring-boot-starter</artifactId>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,gradle]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-bedrock-ai-spring-boot-starter'
+}
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
+
+=== Enable Mistral Chat Support
+
+By default the Mistral model is disabled.
+To enable it, set the `spring.ai.bedrock.mistral.chat.enabled` property to `true`.
+Exporting an environment variable is one way to set this configuration property:
+
+[source,shell]
+----
+export SPRING_AI_BEDROCK_MISTRAL_CHAT_ENABLED=true
+----
+
+=== Chat Properties
+
+The prefix `spring.ai.bedrock.aws` is the property prefix to configure the connection to AWS Bedrock.
+
+[cols="3,3,3"]
+|====
+| Property | Description | Default
+
+| spring.ai.bedrock.aws.region | AWS region to use.
| us-east-1
+| spring.ai.bedrock.aws.timeout | AWS timeout to use. | 5m
+| spring.ai.bedrock.aws.access-key | AWS access key. | -
+| spring.ai.bedrock.aws.secret-key | AWS secret key. | -
+|====
+
+The prefix `spring.ai.bedrock.mistral.chat` is the property prefix that configures the chat model implementation for Mistral.
+
+[cols="2,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.bedrock.mistral.chat.enabled | Enable or disable support for Mistral | false
+| spring.ai.bedrock.mistral.chat.model | The model id to use. See the https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApi.java#L326C14-L326C29[MistralChatModel] for the supported models. | mistral.mistral-large-2402-v1:0
+| spring.ai.bedrock.mistral.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0] | 0.7
+| spring.ai.bedrock.mistral.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. | AWS Bedrock default
+| spring.ai.bedrock.mistral.chat.options.topK | Specify the number of token choices the model uses to generate the next token | AWS Bedrock default
+| spring.ai.bedrock.mistral.chat.options.maxTokens | Specify the maximum number of tokens to use in the generated response. | AWS Bedrock default
+| spring.ai.bedrock.mistral.chat.options.stopSequences | Configure up to four sequences that the model recognizes. | AWS Bedrock default
+|====
+
+Look at the https://github.com/spring-projects/spring-ai/blob/4ba9a3cd689b9fd3a3805f540debe398a079c6ef/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApi.java#L326C14-L326C29[MistralChatModel] for other model IDs.
+Supported values are: `mistral.mistral-7b-instruct-v0:2`, `mistral.mixtral-8x7b-instruct-v0:1`, `mistral.mistral-large-2402-v1:0` and `mistral.mistral-small-2402-v1:0`.
+Model ID values can also be found in the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html[AWS Bedrock documentation for base model IDs].
+
+TIP: All properties prefixed with `spring.ai.bedrock.mistral.chat.options` can be overridden at runtime by adding a request specific <<chat-options>> to the `Prompt` call.
+
+== Runtime Options [[chat-options]]
+
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java[BedrockMistralChatOptions.java] provides model configurations, such as temperature, topK, topP, etc.
+
+On start-up, the default options can be configured with the `BedrockMistralChatModel(api, options)` constructor or the `spring.ai.bedrock.mistral.chat.options.*` properties.
+
+At run-time you can override the default options by adding new, request specific, options to the `Prompt` call.
+For example, to override the default temperature for a specific request:
+
+[source,java]
+----
+ChatResponse response = chatModel.call(
+    new Prompt(
+        "Generate the names of 5 famous pirates.",
+        BedrockMistralChatOptions.builder()
+            .withTemperature(0.4f)
+            .build()
+    ));
+----
+
+TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatOptions.java[BedrockMistralChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()].
+
+== Sample Controller
+
+https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-bedrock-ai-spring-boot-starter` to your pom (or gradle) dependencies.
+
+Add an `application.properties` file, under the `src/main/resources` directory, to enable and configure the Mistral chat model:
+
+[source]
+----
+spring.ai.bedrock.aws.region=eu-central-1
+spring.ai.bedrock.aws.timeout=1000ms
+spring.ai.bedrock.aws.access-key=${AWS_ACCESS_KEY_ID}
+spring.ai.bedrock.aws.secret-key=${AWS_SECRET_ACCESS_KEY}
+
+spring.ai.bedrock.mistral.chat.enabled=true
+spring.ai.bedrock.mistral.chat.options.temperature=0.8
+----
+
+TIP: replace the `region`, `access-key` and `secret-key` with your AWS credentials.
+
+This will create a `BedrockMistralChatModel` implementation that you can inject into your class.
+Here is an example of a simple `@Controller` class that uses the chat model for text generation.
+
+[source,java]
+----
+@RestController
+public class ChatController {
+
+    private final BedrockMistralChatModel chatModel;
+
+    @Autowired
+    public ChatController(BedrockMistralChatModel chatModel) {
+        this.chatModel = chatModel;
+    }
+
+    @GetMapping("/ai/generate")
+    public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        return Map.of("generation", chatModel.call(message));
+    }
+
+    @GetMapping("/ai/generateStream")
+    public Flux<ChatResponse> generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) {
+        Prompt prompt = new Prompt(new UserMessage(message));
+        return chatModel.stream(prompt);
+    }
+}
+----
+
+== Manual Configuration
+
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java[BedrockMistralChatModel] implements the `ChatModel` and `StreamingChatModel` and uses the <<low-level-api>> to connect to the Bedrock Mistral service.
+
+Add the `spring-ai-bedrock` dependency to your project's Maven `pom.xml` file:
+
+[source,xml]
+----
+<dependency>
+    <groupId>org.springframework.ai</groupId>
+    <artifactId>spring-ai-bedrock</artifactId>
+</dependency>
+----
+
+or to your Gradle `build.gradle` build file.
+
+[source,gradle]
+----
+dependencies {
+    implementation 'org.springframework.ai:spring-ai-bedrock'
+}
+----
+
+TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file.
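+
+NOTE: The example below authenticates with `EnvironmentVariableCredentialsProvider`.
+Any AWS SDK v2 `AwsCredentialsProvider` can be supplied instead; the following sketch is illustrative only (it is not part of the original example) and resolves credentials from the default AWS provider chain:
+
+[source,java]
+----
+// Illustrative alternative: resolve credentials from the default AWS provider chain.
+MistralChatBedrockApi api = new MistralChatBedrockApi(MistralChatModel.MISTRAL_LARGE.id(),
+    DefaultCredentialsProvider.create(),
+    Region.US_EAST_1.id(),
+    new ObjectMapper(),
+    Duration.ofMinutes(2));
+----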
+
+Next, create a https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModel.java[BedrockMistralChatModel] and use it for text generation:
+
+[source,java]
+----
+MistralChatBedrockApi api = new MistralChatBedrockApi(MistralChatModel.MISTRAL_8X7B_INSTRUCT.id(),
+    EnvironmentVariableCredentialsProvider.create(),
+    Region.US_EAST_1.id(),
+    new ObjectMapper(),
+    Duration.ofMillis(1000L));
+
+BedrockMistralChatModel chatModel = new BedrockMistralChatModel(api,
+    BedrockMistralChatOptions.builder()
+        .withTemperature(0.6f)
+        .withTopK(10)
+        .withTopP(0.5f)
+        .withMaxTokens(678)
+        .build());
+
+ChatResponse response = chatModel.call(
+    new Prompt("Generate the names of 5 famous pirates."));
+
+// Or with streaming responses
+Flux<ChatResponse> response = chatModel.stream(
+    new Prompt("Generate the names of 5 famous pirates."));
+----
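+
+NOTE: The constructors used above fall back to `RetryUtils.DEFAULT_RETRY_TEMPLATE` for retrying failed Bedrock calls.
+A custom Spring Retry `RetryTemplate` can be passed through the three-argument constructor; the snippet below is a minimal sketch and its retry settings are examples, not recommended values:
+
+[source,java]
+----
+// Retry up to 5 times with exponential backoff (example values only).
+RetryTemplate retryTemplate = RetryTemplate.builder()
+    .maxAttempts(5)
+    .exponentialBackoff(1000, 2.0, 10000)
+    .build();
+
+BedrockMistralChatModel retryingChatModel = new BedrockMistralChatModel(api,
+    BedrockMistralChatOptions.builder().build(), retryTemplate);
+----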
+
+== Low-level MistralChatBedrockApi Client [[low-level-api]]
+
+The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/mistral/api/MistralChatBedrockApi.java[MistralChatBedrockApi] provides a lightweight Java client on top of the AWS Bedrock https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html[Mistral models].
+
+The following class diagram illustrates the MistralChatBedrockApi interface and building blocks:
+
+image::bedrock/bedrock-mistral-chat-low-level-api.png[MistralChatBedrockApi Class Diagram]
+
+The MistralChatBedrockApi supports the `mistral.mistral-7b-instruct-v0:2`, `mistral.mixtral-8x7b-instruct-v0:1`, `mistral.mistral-large-2402-v1:0` and `mistral.mistral-small-2402-v1:0` models for both synchronous (e.g. `chatCompletion()`) and streaming (e.g. `chatCompletionStream()`) requests.
+
+Here is a simple snippet showing how to use the API programmatically:
+
+[source,java]
+----
+MistralChatBedrockApi mistralChatApi = new MistralChatBedrockApi(
+    MistralChatModel.MISTRAL_8X7B_INSTRUCT.id(),
+    Region.US_EAST_1.id(),
+    Duration.ofMillis(1000L));
+
+var request = MistralChatRequest
+    .builder("What is the capital of Bulgaria and what is the size? What is the national anthem?")
+    .withTemperature(0.5f)
+    .withTopP(0.8f)
+    .withTopK(15)
+    .withMaxTokens(100)
+    .withStopSequences(List.of("END"))
+    .build();
+
+MistralChatResponse response = mistralChatApi.chatCompletion(request);
+
+request = MistralChatRequest
+    .builder("What is the capital of Bulgaria and what is the size? What is the national anthem?")
+    .withTemperature(0.5f)
+    .withTopP(0.8f)
+    .withTopK(15)
+    .withMaxTokens(100)
+    .withStopSequences(List.of("END"))
+    .build();
+
+Flux<MistralChatResponse> responseStream = mistralChatApi.chatCompletionStream(request);
+List<MistralChatResponse> responses = responseStream.collectList().block();
+----
+
+
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatmodel.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatmodel.adoc
index 9fdce3b424f..8918e34e0eb 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatmodel.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatmodel.adoc
@@ -202,6 +202,7 @@ image::spring-ai-chat-completions-clients.jpg[align="center", width="800px"]
 ** xref:api/chat/bedrock/bedrock-titan.adoc[Titan Chat Completion]
 ** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic Chat Completion]
 ** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2 Chat Completion]
+** xref:api/chat/bedrock/bedrock-mistral.adoc[Mistral Chat Completion]
 * xref:api/chat/mistralai-chat.adoc[Mistral AI Chat Completion] (streaming & function-calling support)
 * xref:api/chat/anthropic-chat.adoc[Anthropic Chat Completion] (streaming)
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc
index 6346d4e5c16..8d9d9f3b487 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/structured-output-converter.adoc
@@ -210,7 +210,8 @@ The following AI Models have been tested to support List, Map and Bean structure
 | xref:api/chat/bedrock/bedrock-anthropic.adoc[Bedrock Anthropic 2] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatModelIT.java[BedrockAnthropicChatModelIT.java]
 | xref:api/chat/bedrock/bedrock-anthropic3.adoc[Bedrock Anthropic 3] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModelIT.java[BedrockAnthropic3ChatModelIT.java]
 | xref:api/chat/bedrock/bedrock-cohere.adoc[Bedrock Cohere] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatModelIT.java[BedrockCohereChatModelIT.java]
-| xref:api/chat/bedrock/bedrock-llama.adoc[Bedrock Llama] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java[BedrockLlamaChatModelIT.java.java]
+| xref:api/chat/bedrock/bedrock-llama.adoc[Bedrock Llama] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama/BedrockLlamaChatModelIT.java[BedrockLlamaChatModelIT.java]
+| xref:api/chat/bedrock/bedrock-mistral.adoc[Bedrock Mistral] | link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/mistral/BedrockMistralChatModelIT.java[BedrockMistralChatModelIT.java]
 |====
 
 == Build-in JSON mode
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc
index 224df22be4f..b4378566d8c 100644
---
a/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/getting-started.adoc @@ -153,6 +153,7 @@ Each of the following sections in the documentation shows which dependencies you *** xref:api/chat/bedrock/bedrock-titan.adoc[Titan Chat Completion] *** xref:api/chat/bedrock/bedrock-anthropic.adoc[Anthropic Chat Completion] *** xref:api/chat/bedrock/bedrock-jurassic2.adoc[Jurassic2 Chat Completion] +*** xref:api/chat/bedrock/bedrock-mistral.adoc[Mistral Chat Completion] ** xref:api/chat/mistralai-chat.adoc[MistralAI Chat Completion] (streaming and function-calling support) === Image Generation Models diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfiguration.java new file mode 100644 index 00000000000..2c897a5a5c5 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfiguration.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.mistral; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionConfiguration; +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; +import org.springframework.retry.support.RetryTemplate; + +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.providers.AwsRegionProvider; + +/** + * {@link AutoConfiguration Auto-configuration} for Bedrock Mistral Chat Client. 
+ * + * @author Wei Jiang + * @since 1.0.0 + */ +@AutoConfiguration(after = { SpringAiRetryAutoConfiguration.class }) +@ConditionalOnClass(MistralChatBedrockApi.class) +@EnableConfigurationProperties({ BedrockMistralChatProperties.class, BedrockAwsConnectionProperties.class }) +@ConditionalOnProperty(prefix = BedrockMistralChatProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true") +@Import(BedrockAwsConnectionConfiguration.class) +public class BedrockMistralChatAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + @ConditionalOnBean({ AwsCredentialsProvider.class, AwsRegionProvider.class }) + public MistralChatBedrockApi mistralChatApi(AwsCredentialsProvider credentialsProvider, + AwsRegionProvider regionProvider, BedrockMistralChatProperties properties, + BedrockAwsConnectionProperties awsProperties) { + return new MistralChatBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(), + new ObjectMapper(), awsProperties.getTimeout()); + } + + @Bean + @ConditionalOnBean(MistralChatBedrockApi.class) + public BedrockMistralChatModel mistralChatModel(MistralChatBedrockApi mistralChatApi, + BedrockMistralChatProperties properties, RetryTemplate retryTemplate) { + + return new BedrockMistralChatModel(mistralChatApi, properties.getOptions(), retryTemplate); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatProperties.java new file mode 100644 index 00000000000..e3952efcecd --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatProperties.java @@ -0,0 +1,72 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.bedrock.mistral; + +import org.springframework.ai.bedrock.mistral.BedrockMistralChatOptions; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +/** + * Bedrock Mistral Chat autoconfiguration properties. + * + * @author Wei Jiang + * @since 1.0.0 + */ +@ConfigurationProperties(BedrockMistralChatProperties.CONFIG_PREFIX) +public class BedrockMistralChatProperties { + + public static final String CONFIG_PREFIX = "spring.ai.bedrock.mistral.chat"; + + /** + * Enable Bedrock Mistral Chat Client. False by default. + */ + private boolean enabled = false; + + /** + * Bedrock Mistral Chat generative name. Defaults to + * 'mistral.mistral-large-2402-v1:0'. 
+ */ + private String model = MistralChatBedrockApi.MistralChatModel.MISTRAL_LARGE.id(); + + @NestedConfigurationProperty + private BedrockMistralChatOptions options = BedrockMistralChatOptions.builder().build(); + + public boolean isEnabled() { + return this.enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + public BedrockMistralChatOptions getOptions() { + return this.options; + } + + public void setOptions(BedrockMistralChatOptions options) { + this.options = options; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index c744be669c2..dfde9b22cef 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -13,6 +13,7 @@ org.springframework.ai.autoconfigure.bedrock.anthropic.BedrockAnthropicChatAutoC org.springframework.ai.autoconfigure.bedrock.anthropic3.BedrockAnthropic3ChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanChatAutoConfiguration org.springframework.ai.autoconfigure.bedrock.titan.BedrockTitanEmbeddingAutoConfiguration +org.springframework.ai.autoconfigure.bedrock.mistral.BedrockMistralChatAutoConfiguration org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.pgvector.PgVectorStoreAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfigurationIT.java new file mode 100644 index 00000000000..dbef230a05c --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/mistral/BedrockMistralChatAutoConfigurationIT.java @@ -0,0 +1,159 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.autoconfigure.bedrock.mistral; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.messages.AssistantMessage; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.regions.Region; + +import org.springframework.ai.autoconfigure.bedrock.BedrockAwsConnectionProperties; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; +import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel; +import org.springframework.ai.bedrock.mistral.api.MistralChatBedrockApi.MistralChatModel; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Wei Jiang + * @since 1.0.0 + */ +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockMistralChatAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=" + System.getenv("AWS_ACCESS_KEY_ID"), + "spring.ai.bedrock.aws.secret-key=" + System.getenv("AWS_SECRET_ACCESS_KEY"), + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.mistral.chat.model=" + MistralChatModel.MISTRAL_SMALL.id(), + "spring.ai.bedrock.mistral.chat.options.temperature=0.5") + .withConfiguration( + AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, BedrockMistralChatAutoConfiguration.class)); + + private final Message systemMessage = new SystemPromptTemplate(""" + You are a helpful AI assistant. Your name is {name}. + You are an AI assistant that helps people find information. + Your name is {name} + You should reply to the user's request with your name and also in the style of a {voice}. 
+ """).createMessage(Map.of("name", "Bob", "voice", "pirate")); + + private final UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); + + @Test + public void chatCompletion() { + contextRunner.run(context -> { + BedrockMistralChatModel mistralChatModel = context.getBean(BedrockMistralChatModel.class); + ChatResponse response = mistralChatModel.call(new Prompt(List.of(userMessage, systemMessage))); + assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard"); + }); + } + + @Test + public void chatCompletionStreaming() { + contextRunner.run(context -> { + + BedrockMistralChatModel mistralChatModel = context.getBean(BedrockMistralChatModel.class); + + Flux response = mistralChatModel.stream(new Prompt(List.of(userMessage, systemMessage))); + + List responses = response.collectList().block(); + assertThat(responses.size()).isGreaterThan(2); + + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(stitchedResponseContent).contains("Blackbeard"); + }); + } + + @Test + public void propertiesTest() { + + new ApplicationContextRunner() + .withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true", + "spring.ai.bedrock.aws.access-key=ACCESS_KEY", "spring.ai.bedrock.aws.secret-key=SECRET_KEY", + "spring.ai.bedrock.mistral.chat.model=MODEL_XYZ", + "spring.ai.bedrock.aws.region=" + Region.US_EAST_1.id(), + "spring.ai.bedrock.mistral.chat.options.temperature=0.55") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + var mistralChatProperties = context.getBean(BedrockMistralChatProperties.class); + var awsProperties = context.getBean(BedrockAwsConnectionProperties.class); + + assertThat(mistralChatProperties.isEnabled()).isTrue(); + assertThat(awsProperties.getRegion()).isEqualTo(Region.US_EAST_1.id()); + + assertThat(mistralChatProperties.getOptions().getTemperature()).isEqualTo(0.55f); + assertThat(mistralChatProperties.getModel()).isEqualTo("MODEL_XYZ"); + + assertThat(awsProperties.getAccessKey()).isEqualTo("ACCESS_KEY"); + assertThat(awsProperties.getSecretKey()).isEqualTo("SECRET_KEY"); + }); + } + + @Test + public void chatCompletionDisabled() { + + // It is disabled by default + new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockMistralChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockMistralChatModel.class)).isEmpty(); + }); + + // Explicitly enable the chat auto-configuration. + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=true") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockMistralChatProperties.class)).isNotEmpty(); + assertThat(context.getBeansOfType(BedrockMistralChatModel.class)).isNotEmpty(); + }); + + // Explicitly disable the chat auto-configuration. 
+ new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.mistral.chat.enabled=false") + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + BedrockMistralChatAutoConfiguration.class)) + .run(context -> { + assertThat(context.getBeansOfType(BedrockMistralChatProperties.class)).isEmpty(); + assertThat(context.getBeansOfType(BedrockMistralChatModel.class)).isEmpty(); + }); + } + +}
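For readers trying this change out, the integration test above already spells out the required wiring: the auto-configuration only activates when `spring.ai.bedrock.mistral.chat.enabled=true`, credentials and region come from the `spring.ai.bedrock.aws.*` properties, and the model id can be overridden via `spring.ai.bedrock.mistral.chat.model`. Below is a minimal, hypothetical consumer of the auto-configured `BedrockMistralChatModel`; the component class and method names are illustrative only, while the bean type, the `Prompt`/`UserMessage` construction, and the response accessors mirror the `chatCompletion()` test in this PR.

```java
import java.util.List;

import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Component;

/**
 * Hypothetical consumer of the auto-configured chat model. Assumes the
 * following properties are set (mirroring the integration test):
 *
 *   spring.ai.bedrock.mistral.chat.enabled=true
 *   spring.ai.bedrock.aws.region=us-east-1
 *   spring.ai.bedrock.aws.access-key=... / spring.ai.bedrock.aws.secret-key=...
 */
@Component
class PirateFacts {

	private final BedrockMistralChatModel chatModel;

	PirateFacts(BedrockMistralChatModel chatModel) {
		this.chatModel = chatModel;
	}

	String ask(String question) {
		// Blocking call, same shape as chatCompletion() in the IT above.
		Message userMessage = new UserMessage(question);
		ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage)));
		return response.getResult().getOutput().getContent();
	}

}
```

Generation defaults such as temperature can likewise be tuned through `spring.ai.bedrock.mistral.chat.options.*` (the tests exercise `options.temperature`), which populate the nested `BedrockMistralChatOptions` shown in the properties class.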
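Since the model also supports streaming (exercised by `chatCompletionStreaming()`), a caller can consume partial responses as a `Flux<ChatResponse>`. The sketch below makes the same assumptions as the previous example; the stitching logic is lifted directly from the integration test and the helper class is, again, purely illustrative.

```java
import java.util.List;
import java.util.stream.Collectors;

import org.springframework.ai.bedrock.mistral.BedrockMistralChatModel;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;

import reactor.core.publisher.Flux;

class StreamingExample {

	// Streams the completion and joins the chunks into one string,
	// exactly as the IT does before its "Blackbeard" assertion.
	static String streamAndJoin(BedrockMistralChatModel chatModel, String question) {
		Message userMessage = new UserMessage(question);
		Flux<ChatResponse> stream = chatModel.stream(new Prompt(List.of(userMessage)));

		return stream.collectList()
			.block()
			.stream()
			.map(ChatResponse::getResults)
			.flatMap(List::stream)
			.map(Generation::getOutput)
			.map(AssistantMessage::getContent)
			.collect(Collectors.joining());
	}

}
```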