diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java index 6eb31aa28e7..0ada7895fa2 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/BedrockUsage.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.bedrock; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics; import org.springframework.ai.chat.metadata.Usage; import org.springframework.util.Assert; @@ -23,38 +24,45 @@ * {@link Usage} implementation for Bedrock API. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockUsage implements Usage { public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) { - return new BedrockUsage(usage); + Assert.notNull(usage, "Amazon Bedrock Usage must not be null"); + + return new BedrockUsage(usage.inputTokenCount().longValue(), usage.outputTokenCount().longValue()); } - private final AmazonBedrockInvocationMetrics usage; + public static BedrockUsage from(AmazonBedrockInvocationMetadata metadata) { + Assert.notNull(metadata, "Amazon Bedrock Invocation Metadata must not be null"); - protected BedrockUsage(AmazonBedrockInvocationMetrics usage) { - Assert.notNull(usage, "OpenAI Usage must not be null"); - this.usage = usage; + return new BedrockUsage(metadata.inputTokenCount(), metadata.outputTokenCount()); } - protected AmazonBedrockInvocationMetrics getUsage() { - return this.usage; + private Long promptTokens; + + private Long generationTokens; + + protected BedrockUsage(Long promptTokens, Long generationTokens) { + this.promptTokens = promptTokens; + this.generationTokens = generationTokens; } @Override public Long getPromptTokens() { - return getUsage().inputTokenCount().longValue(); + return this.promptTokens; } @Override public Long getGenerationTokens() { - return getUsage().outputTokenCount().longValue(); + return this.generationTokens; } @Override public String toString() { - return getUsage().toString(); + return "BedrockUsage [promptTokens=" + promptTokens + ", generationTokens=" + generationTokens + "]"; } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java index 26b88734471..b314c365c38 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/BedrockAnthropicChatClient.java @@ -27,6 +27,8 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; +import org.springframework.ai.bedrock.anthropic.metadata.BedrockAnthropicChatResponseMetadata; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.prompt.Prompt; @@ -37,6 +39,7 @@ * generative. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClient { @@ -65,9 +68,16 @@ public ChatResponse call(Prompt prompt) { AnthropicChatRequest request = createRequest(prompt); - AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); + AmazonBedrockInvocationContext context = anthropicChatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.completion()))); + AnthropicChatResponse response = context.response(); + + List generations = List.of(new Generation(response.completion())); + + BedrockAnthropicChatResponseMetadata chatResponseMetadata = BedrockAnthropicChatResponseMetadata.from(response, + context.metadata()); + + return new ChatResponse(generations, chatResponseMetadata); } @Override diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java index 2437b35eebb..59509d77895 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApi.java @@ -238,7 +238,7 @@ public String id() { } @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { + public AmazonBedrockInvocationContext chatCompletion(AnthropicChatRequest anthropicRequest) { Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/metadata/BedrockAnthropicChatResponseMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/metadata/BedrockAnthropicChatResponseMetadata.java new file mode 100644 index 00000000000..0951e86707a --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/metadata/BedrockAnthropicChatResponseMetadata.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.anthropic.metadata; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link ChatResponseMetadata} implementation for + * {@literal Amazon Bedrock Anthropic Chat Model}. + * + * @author Wei Jiang + * @see ChatResponseMetadata + * @since 0.8.1 + */ +public class BedrockAnthropicChatResponseMetadata implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }"; + + public static BedrockAnthropicChatResponseMetadata from(AnthropicChatResponse response, + AmazonBedrockInvocationMetadata invocationMetadata) { + Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null"); + + BedrockUsage usage = BedrockUsage.from(invocationMetadata); + + BedrockAnthropicChatResponseMetadata chatResponseMetadata = new BedrockAnthropicChatResponseMetadata( + invocationMetadata.awsRequestId(), invocationMetadata.invocationLatency(), usage); + return chatResponseMetadata; + } + + private final String id; + + private Long invocationLatency; + + private final Usage usage; + + protected BedrockAnthropicChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) { + this.id = id; + this.invocationLatency = invocationLatency; + this.usage = usage; + } + + public String getId() { + return this.id; + } + + public Long getInvocationLatency() { + return this.invocationLatency; + } + + @Override + public Usage getUsage() { + return this.usage; + } + + @Override + public PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(), + getRateLimit()); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClient.java index 12dba850c41..89a24c5e5bf 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatClient.java @@ -20,6 +20,8 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; +import org.springframework.ai.bedrock.anthropic3.metadata.BedrockAnthropic3ChatResponseMetadata; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; import org.springframework.ai.chat.ChatClient; @@ -48,6 +50,7 @@ * * @author Ben Middleton * @author Christian Tzolov + * @author Wei Jiang * @since 1.0.0 */ public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatClient { @@ -76,9 +79,16 @@ public ChatResponse call(Prompt prompt) { AnthropicChatRequest request = createRequest(prompt); - AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request); + AmazonBedrockInvocationContext context = this.anthropicChatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.content().get(0).text()))); + AnthropicChatResponse response = context.response(); + + List generations = List.of(new Generation(response.content().get(0).text())); + + BedrockAnthropic3ChatResponseMetadata chatResponseMetadata = BedrockAnthropic3ChatResponseMetadata + .from(response, context.metadata()); + + return new ChatResponse(generations, chatResponseMetadata); } @Override diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java index e76bfcbeff3..eabeb3d9bbc 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApi.java @@ -459,7 +459,7 @@ public String id() { } @Override - public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) { + public AmazonBedrockInvocationContext chatCompletion(AnthropicChatRequest anthropicRequest) { Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null"); return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/metadata/BedrockAnthropic3ChatResponseMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/metadata/BedrockAnthropic3ChatResponseMetadata.java new file mode 100644 index 00000000000..08bf1b639cd --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/metadata/BedrockAnthropic3ChatResponseMetadata.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.anthropic3.metadata; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link ChatResponseMetadata} implementation for + * {@literal Amazon Bedrock Anthropic Chat Model}. + * + * @author Wei Jiang + * @see ChatResponseMetadata + * @since 0.8.1 + */ +public class BedrockAnthropic3ChatResponseMetadata implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }"; + + public static BedrockAnthropic3ChatResponseMetadata from(AnthropicChatResponse response, + AmazonBedrockInvocationMetadata invocationMetadata) { + Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null"); + + BedrockUsage usage = BedrockUsage.from(invocationMetadata); + + BedrockAnthropic3ChatResponseMetadata chatResponseMetadata = new BedrockAnthropic3ChatResponseMetadata( + response.id(), invocationMetadata.invocationLatency(), usage); + return chatResponseMetadata; + } + + private final String id; + + private Long invocationLatency; + + private final Usage usage; + + protected BedrockAnthropic3ChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) { + this.id = id; + this.invocationLatency = invocationLatency; + this.usage = usage; + } + + public String getId() { + return this.id; + } + + public Long getInvocationLatency() { + return this.invocationLatency; + } + + @Override + public Usage getUsage() { + return this.usage; + } + + @Override + public PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(), + getRateLimit()); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java index 74f4249f04f..36d6c476de5 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java @@ -34,6 +34,7 @@ import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; @@ -61,6 +62,7 @@ * @see Model Parameters * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public abstract class AbstractBedrockApi { @@ -93,7 +95,7 @@ public AbstractBedrockApi(String modelId, String region, Duration timeout) { this(modelId, ProfileCredentialsProvider.builder().build(), region, ModelOptionsUtils.OBJECT_MAPPER, timeout); } - /** + /** * Create a new AbstractBedrockApi instance using the provided credentials provider, region and object mapper. * * @param modelId The model id to use. @@ -175,13 +177,48 @@ public record AmazonBedrockInvocationMetrics( @JsonProperty("invocationLatency") Long invocationLatency) { } + /** + * Amazon Bedrock Invocation context, includes the model invocation response and metadata. + * + * @param response The model invocation response. + * @param metadata The model invocation response metadata. + */ + public record AmazonBedrockInvocationContext(T response, AmazonBedrockInvocationMetadata metadata) {} + + /** + * Amazon Bedrock Invocation metadata. + * + * @param awsRequestId The AWS Bedrock invocation request id. + * @param inputTokenCount The input token count during this invocation. + * @param outputTokenCount The output token count during this invocation. + * @param invocationLatency The invocation latency in millisecondes during this invocation. + */ + public record AmazonBedrockInvocationMetadata(String awsRequestId, Long inputTokenCount, Long outputTokenCount, Long invocationLatency) { + private static final String INPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Input-Token-Count"; + private static final String OUTPUT_TOKEN_COUNT = "X-Amzn-Bedrock-Output-Token-Count"; + private static final String INVOCATION_LATENCY = "X-Amzn-Bedrock-Invocation-Latency"; + private static final String AWS_REQUEST_ID = "x-amzn-RequestId"; + + public static AmazonBedrockInvocationMetadata from(InvokeModelResponse response) { + SdkHttpResponse sdkHttpResponse = response.sdkHttpResponse(); + + String awsRequestId = sdkHttpResponse.firstMatchingHeader(AWS_REQUEST_ID).orElse(null); + Long inputTokenCount = sdkHttpResponse.firstMatchingHeader(INPUT_TOKEN_COUNT).map(Long::valueOf).orElse(0L); + Long outputTokenCount = sdkHttpResponse.firstMatchingHeader(OUTPUT_TOKEN_COUNT).map(Long::valueOf).orElse(0L); + Long invocationLatency = sdkHttpResponse.firstMatchingHeader(INVOCATION_LATENCY).map(Long::valueOf).orElse(0L); + + return new AmazonBedrockInvocationMetadata(awsRequestId, inputTokenCount, outputTokenCount, invocationLatency); + } + + } + /** * Compute the embedding for the given text. * * @param request The embedding request. - * @return Returns the embedding response. + * @return Returns the embedding response context. */ - protected O embedding(I request) { + protected AmazonBedrockInvocationContext embedding(I request) { throw new UnsupportedOperationException("Embedding is not supported for this model: " + this.modelId); } @@ -189,9 +226,9 @@ protected O embedding(I request) { * Chat completion invocation. * * @param request The chat completion request. - * @return The chat completion response. + * @return The chat completion response context. */ - protected O chatCompletion(I request) { + protected AmazonBedrockInvocationContext chatCompletion(I request) { throw new UnsupportedOperationException("Chat completion is not supported for this model: " + this.modelId); } @@ -215,13 +252,12 @@ protected Flux chatCompletionStream(I request) { * @param request Model invocation request. * @param clazz The response class type * @return The model invocation response. - * */ - protected O internalInvocation(I request, Class clazz) { + protected AmazonBedrockInvocationContext internalInvocation(I request, Class clazz) { SdkBytes body; try { - body = SdkBytes.fromUtf8String(new ObjectMapper().writeValueAsString(request)); + body = SdkBytes.fromUtf8String(objectMapper.writeValueAsString(request)); } catch (JsonProcessingException e) { throw new IllegalArgumentException("Invalid JSON format for the input request: " + request, e); @@ -237,7 +273,11 @@ protected O internalInvocation(I request, Class clazz) { String responseBody = response.body().asString(StandardCharsets.UTF_8); try { - return this.objectMapper.readValue(responseBody, clazz); + O invokeResponse = this.objectMapper.readValue(responseBody, clazz); + + AmazonBedrockInvocationMetadata metadata = AmazonBedrockInvocationMetadata.from(response); + + return new AmazonBedrockInvocationContext<>(invokeResponse, metadata); } catch (JsonProcessingException | UncheckedIOException e) { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java index 3ff2b2ceecb..65ee1ceea2b 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereChatClient.java @@ -21,9 +21,11 @@ import org.springframework.ai.bedrock.BedrockUsage; import org.springframework.ai.bedrock.MessageToPromptConverter; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; +import org.springframework.ai.bedrock.cohere.metadata.CohereChatResponseMetadata; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.ChatResponse; @@ -37,6 +39,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockCohereChatClient implements ChatClient, StreamingChatClient { @@ -59,12 +62,19 @@ public BedrockCohereChatClient(CohereChatBedrockApi chatApi, BedrockCohereChatOp @Override public ChatResponse call(Prompt prompt) { - CohereChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt, false)); + CohereChatRequest request = this.createRequest(prompt, false); + + AmazonBedrockInvocationContext context = chatApi.chatCompletion(request); + + CohereChatResponse response = context.response(); + List generations = response.generations().stream().map(g -> { return new Generation(g.text()); }).toList(); - return new ChatResponse(generations); + CohereChatResponseMetadata chatResponseMetadata = CohereChatResponseMetadata.from(response, context.metadata()); + + return new ChatResponse(generations, chatResponseMetadata); } @Override diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java index 2c0145059f9..8e400457d67 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/BedrockCohereEmbeddingClient.java @@ -104,7 +104,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { var apiRequest = new CohereEmbeddingRequest(request.getInstructions(), optionsToUse.getInputType(), optionsToUse.getTruncate()); - CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest); + CohereEmbeddingResponse apiResponse = this.embeddingApi.embedding(apiRequest).response(); var indexCounter = new AtomicInteger(0); List embeddings = apiResponse.embeddings() .stream() diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java index b3b02b6993f..f35b88efa22 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApi.java @@ -377,7 +377,7 @@ public String id() { } @Override - public CohereChatResponse chatCompletion(CohereChatRequest request) { + public AmazonBedrockInvocationContext chatCompletion(CohereChatRequest request) { Assert.isTrue(!request.stream(), "The request must be configured to return the complete response!"); return this.internalInvocation(request, CohereChatResponse.class); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java index 7d0fa442cde..b982023b179 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApi.java @@ -202,7 +202,7 @@ public String id() { } @Override - public CohereEmbeddingResponse embedding(CohereEmbeddingRequest request) { + public AmazonBedrockInvocationContext embedding(CohereEmbeddingRequest request) { return this.internalInvocation(request, CohereEmbeddingResponse.class); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/metadata/CohereChatResponseMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/metadata/CohereChatResponseMetadata.java new file mode 100644 index 00000000000..627d9d273e0 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/cohere/metadata/CohereChatResponseMetadata.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.cohere.metadata; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; +import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link ChatResponseMetadata} implementation for + * {@literal Amazon Bedrock Cohere chat model}. + * + * @author Wei Jiang + * @see ChatResponseMetadata + * @since 0.8.1 + */ +public class CohereChatResponseMetadata implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }"; + + public static CohereChatResponseMetadata from(CohereChatResponse response, + AmazonBedrockInvocationMetadata invocationMetadata) { + Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null"); + + BedrockUsage usage = BedrockUsage.from(invocationMetadata); + + CohereChatResponseMetadata chatResponseMetadata = new CohereChatResponseMetadata(response.id(), + invocationMetadata.invocationLatency(), usage); + return chatResponseMetadata; + } + + private final String id; + + private Long invocationLatency; + + private final Usage usage; + + protected CohereChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) { + this.id = id; + this.invocationLatency = invocationLatency; + this.usage = usage; + } + + public String getId() { + return this.id; + } + + public Long getInvocationLatency() { + return this.invocationLatency; + } + + @Override + public Usage getUsage() { + return this.usage; + } + + @Override + public PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(), + getRateLimit()); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatClient.java index 7a11a25246c..614daee8ef7 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/BedrockAi21Jurassic2ChatClient.java @@ -19,6 +19,7 @@ import org.springframework.ai.bedrock.MessageToPromptConverter; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; +import org.springframework.ai.bedrock.jurassic2.metadata.BedrockAi21Jurassic2ChatResponseMetadata; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; @@ -32,6 +33,7 @@ * Java {@link ChatClient} for the Bedrock Jurassic2 chat generative model. * * @author Ahmed Yousri + * @author Wei Jiang * @since 1.0.0 */ public class BedrockAi21Jurassic2ChatClient implements ChatClient { @@ -61,13 +63,20 @@ public BedrockAi21Jurassic2ChatClient(Ai21Jurassic2ChatBedrockApi chatApi) { @Override public ChatResponse call(Prompt prompt) { var request = createRequest(prompt); - var response = this.chatApi.chatCompletion(request); - return new ChatResponse(response.completions() + var context = this.chatApi.chatCompletion(request); + + var response = context.response(); + + var generations = response.completions() .stream() .map(completion -> new Generation(completion.data().text()) .withGenerationMetadata(ChatGenerationMetadata.from(completion.finishReason().reason(), null))) - .toList()); + .toList(); + + var chatResponseMetadata = BedrockAi21Jurassic2ChatResponseMetadata.from(response, context.metadata()); + + return new ChatResponse(generations, chatResponseMetadata); } private Ai21Jurassic2ChatRequest createRequest(Prompt prompt) { diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java index fa505176350..9a3a67801f4 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApi.java @@ -382,7 +382,7 @@ public String id() { } @Override - public Ai21Jurassic2ChatResponse chatCompletion(Ai21Jurassic2ChatRequest request) { + public AmazonBedrockInvocationContext chatCompletion(Ai21Jurassic2ChatRequest request) { return this.internalInvocation(request, Ai21Jurassic2ChatResponse.class); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/metadata/BedrockAi21Jurassic2ChatResponseMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/metadata/BedrockAi21Jurassic2ChatResponseMetadata.java new file mode 100644 index 00000000000..a0b35547676 --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/jurassic2/metadata/BedrockAi21Jurassic2ChatResponseMetadata.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.jurassic2.metadata; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; +import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link ChatResponseMetadata} implementation for + * {@literal Amazon Bedrock Ai21Jurassic2 chat model}. + * + * @author Wei Jiang + * @see ChatResponseMetadata + * @since 0.8.1 + */ +public class BedrockAi21Jurassic2ChatResponseMetadata implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }"; + + public static BedrockAi21Jurassic2ChatResponseMetadata from(Ai21Jurassic2ChatResponse response, + AmazonBedrockInvocationMetadata invocationMetadata) { + Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null"); + + BedrockUsage usage = BedrockUsage.from(invocationMetadata); + + BedrockAi21Jurassic2ChatResponseMetadata chatResponseMetadata = new BedrockAi21Jurassic2ChatResponseMetadata( + response.id(), invocationMetadata.invocationLatency(), usage); + return chatResponseMetadata; + } + + private final String id; + + private Long invocationLatency; + + private final Usage usage; + + protected BedrockAi21Jurassic2ChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) { + this.id = id; + this.invocationLatency = invocationLatency; + this.usage = usage; + } + + public String getId() { + return this.id; + } + + public Long getInvocationLatency() { + return this.invocationLatency; + } + + @Override + public Usage getUsage() { + return this.usage; + } + + @Override + public PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(), + getRateLimit()); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java index a12fe0b6eb9..a5050a5fc3e 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/BedrockLlama2ChatClient.java @@ -23,6 +23,7 @@ import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest; import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse; +import org.springframework.ai.bedrock.llama2.metadata.BedrockLlama2ChatResponseMetadata; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.ChatResponse; @@ -39,6 +40,7 @@ * generative. * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockLlama2ChatClient implements ChatClient, StreamingChatClient { @@ -65,10 +67,16 @@ public ChatResponse call(Prompt prompt) { var request = createRequest(prompt); - Llama2ChatResponse response = this.chatApi.chatCompletion(request); + var context = this.chatApi.chatCompletion(request); - return new ChatResponse(List.of(new Generation(response.generation()).withGenerationMetadata( - ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response))))); + Llama2ChatResponse response = context.response(); + + var generations = List.of(new Generation(response.generation()) + .withGenerationMetadata(ChatGenerationMetadata.from(response.stopReason().name(), extractUsage(response)))); + + var chatResponseMetadata = BedrockLlama2ChatResponseMetadata.from(response, context.metadata()); + + return new ChatResponse(generations, chatResponseMetadata); } @Override diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java index af10d69bdfb..58278c60238 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApi.java @@ -215,7 +215,7 @@ public String id() { } @Override - public Llama2ChatResponse chatCompletion(Llama2ChatRequest request) { + public AmazonBedrockInvocationContext chatCompletion(Llama2ChatRequest request) { return this.internalInvocation(request, Llama2ChatResponse.class); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/metadata/BedrockLlama2ChatResponseMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/metadata/BedrockLlama2ChatResponseMetadata.java new file mode 100644 index 00000000000..33c46f9013d --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/llama2/metadata/BedrockLlama2ChatResponseMetadata.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.llama2.metadata; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; +import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link ChatResponseMetadata} implementation for + * {@literal Amazon Bedrock Llama2 chat model}. + * + * @author Wei Jiang + * @see ChatResponseMetadata + * @since 0.8.1 + */ +public class BedrockLlama2ChatResponseMetadata implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }"; + + public static BedrockLlama2ChatResponseMetadata from(Llama2ChatResponse response, + AmazonBedrockInvocationMetadata invocationMetadata) { + Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null"); + + BedrockUsage usage = BedrockUsage.from(invocationMetadata); + + BedrockLlama2ChatResponseMetadata chatResponseMetadata = new BedrockLlama2ChatResponseMetadata( + invocationMetadata.awsRequestId(), invocationMetadata.invocationLatency(), usage); + return chatResponseMetadata; + } + + private final String id; + + private Long invocationLatency; + + private final Usage usage; + + protected BedrockLlama2ChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) { + this.id = id; + this.invocationLatency = invocationLatency; + this.usage = usage; + } + + public String getId() { + return this.id; + } + + public Long getInvocationLatency() { + return this.invocationLatency; + } + + @Override + public Usage getUsage() { + return this.usage; + } + + @Override + public PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(), + getRateLimit()); + } + +} diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java index e77d8277cb2..95477df8b25 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanChatClient.java @@ -20,10 +20,12 @@ import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.MessageToPromptConverter; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk; +import org.springframework.ai.bedrock.titan.metadata.BedrockTitanChatResponseMetadata; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.ChatResponse; @@ -37,6 +39,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ public class BedrockTitanChatClient implements ChatClient, StreamingChatClient { @@ -58,12 +61,20 @@ public BedrockTitanChatClient(TitanChatBedrockApi chatApi, BedrockTitanChatOptio @Override public ChatResponse call(Prompt prompt) { - TitanChatResponse response = this.chatApi.chatCompletion(this.createRequest(prompt)); + TitanChatRequest request = this.createRequest(prompt); + + AmazonBedrockInvocationContext context = this.chatApi.chatCompletion(request); + + TitanChatResponse response = context.response(); + List generations = response.results().stream().map(result -> { return new Generation(result.outputText()); }).toList(); - return new ChatResponse(generations); + BedrockTitanChatResponseMetadata chatResponseMetadata = BedrockTitanChatResponseMetadata.from(response, + context.metadata()); + + return new ChatResponse(generations, chatResponseMetadata); } @Override diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java index d48135f80ec..f0e8d05de93 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingClient.java @@ -90,7 +90,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { var apiRequest = (this.inputType == InputType.IMAGE) ? new TitanEmbeddingRequest.Builder().withInputImage(inputContent).build() : new TitanEmbeddingRequest.Builder().withInputText(inputContent).build(); - TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest); + TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest).response(); embeddingList.add(response.embedding()); } var indexCounter = new AtomicInteger(0); diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java index 498b34bf3d8..63d89801428 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApi.java @@ -38,6 +38,7 @@ * https://docs.aws.amazon.com/bedrock/latest/userguide/titan-text-models.html * * @author Christian Tzolov + * @author Wei Jiang * @since 0.8.0 */ // @formatter:off @@ -276,7 +277,7 @@ public String id() { } @Override - public TitanChatResponse chatCompletion(TitanChatRequest request) { + public AmazonBedrockInvocationContext chatCompletion(TitanChatRequest request) { return this.internalInvocation(request, TitanChatResponse.class); } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java index 9c1dcb3b267..49e9f1338b4 100644 --- a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApi.java @@ -154,7 +154,7 @@ public String id() { } @Override - public TitanEmbeddingResponse embedding(TitanEmbeddingRequest request) { + public AmazonBedrockInvocationContext embedding(TitanEmbeddingRequest request) { return this.internalInvocation(request, TitanEmbeddingResponse.class); } } diff --git a/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/metadata/BedrockTitanChatResponseMetadata.java b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/metadata/BedrockTitanChatResponseMetadata.java new file mode 100644 index 00000000000..da487d09ade --- /dev/null +++ b/models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/metadata/BedrockTitanChatResponseMetadata.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.titan.metadata; + +import org.springframework.ai.bedrock.BedrockUsage; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.PromptMetadata; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.util.Assert; + +/** + * {@link ChatResponseMetadata} implementation for + * {@literal Amazon Bedrock Titan chat model}. + * + * @author Wei Jiang + * @see ChatResponseMetadata + * @since 0.8.1 + */ +public class BedrockTitanChatResponseMetadata implements ChatResponseMetadata { + + protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }"; + + public static BedrockTitanChatResponseMetadata from(TitanChatResponse response, + AmazonBedrockInvocationMetadata invocationMetadata) { + Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null"); + + BedrockUsage usage = BedrockUsage.from(invocationMetadata); + + BedrockTitanChatResponseMetadata chatResponseMetadata = new BedrockTitanChatResponseMetadata( + invocationMetadata.awsRequestId(), invocationMetadata.invocationLatency(), usage); + return chatResponseMetadata; + } + + private final String id; + + private Long invocationLatency; + + private final Usage usage; + + protected BedrockTitanChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) { + this.id = id; + this.invocationLatency = invocationLatency; + this.usage = usage; + } + + public String getId() { + return this.id; + } + + public Long getInvocationLatency() { + return this.invocationLatency; + } + + @Override + public Usage getUsage() { + return this.usage; + } + + @Override + public PromptMetadata getPromptMetadata() { + return PromptMetadata.empty(); + } + + @Override + public String toString() { + return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(), + getRateLimit()); + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java index 334efa48ff5..3d93f0ed43d 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/api/AnthropicChatBedrockApiIT.java @@ -30,6 +30,8 @@ import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatModel; import static org.assertj.core.api.Assertions.assertThat;; @@ -57,7 +59,14 @@ public void chatCompletion() { .withTopK(10) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AmazonBedrockInvocationContext context = anthropicChatApi.chatCompletion(request); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + AnthropicChatResponse response = context.response(); + assertThat(response).isNotNull(); System.out.println(response.completion()); assertThat(response).isNotNull(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/metadata/BedrockAnthropicChatResponseMetadataT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/metadata/BedrockAnthropicChatResponseMetadataT.java new file mode 100644 index 00000000000..a7f24ead218 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic/metadata/BedrockAnthropicChatResponseMetadataT.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.anthropic.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.anthropic.BedrockAnthropicChatClient; +import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockAnthropicChatResponseMetadataT { + + @Autowired + private BedrockAnthropicChatClient client; + + @Test + public void chatCompletion() { + Prompt prompt = new Prompt("Can I fly like a bird?"); + + ChatResponse response = this.client.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata()).isInstanceOf(BedrockAnthropicChatResponseMetadata.class); + + BedrockAnthropicChatResponseMetadata metadata = (BedrockAnthropicChatResponseMetadata) response.getMetadata(); + assertThat(metadata.getId()).isNotNull(); + assertThat(metadata.getInvocationLatency()).isNotNull(); + + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public AnthropicChatBedrockApi anthropicApi() { + return new AnthropicChatBedrockApi(AnthropicChatBedrockApi.AnthropicChatModel.CLAUDE_V2.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockAnthropicChatClient anthropicChatClient(AnthropicChatBedrockApi anthropicApi) { + return new BedrockAnthropicChatClient(anthropicApi); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java index 15ab3dd0f5f..715678f1bae 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/api/Anthropic3ChatBedrockApiIT.java @@ -25,6 +25,8 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage; import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role; import reactor.core.publisher.Flux; @@ -40,6 +42,7 @@ /** * @author Ben Middleton + * @author Wei Jiang */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") @@ -63,11 +66,17 @@ public void chatCompletion() { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AmazonBedrockInvocationContext context = anthropicChatApi.chatCompletion(request); + assertThat(context).isNotNull(); - logger.info("" + response.content()); + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + AnthropicChatResponse response = context.response(); assertThat(response).isNotNull(); + + logger.info("" + response.content()); + assertThat(response.content().get(0).text()).isNotEmpty(); assertThat(response.content().get(0).text()).contains("Blackbeard"); assertThat(response.stopReason()).isEqualTo("end_turn"); @@ -103,7 +112,14 @@ public void chatMultiCompletion() { .withAnthropicVersion(DEFAULT_ANTHROPIC_VERSION) .build(); - AnthropicChatResponse response = anthropicChatApi.chatCompletion(request); + AmazonBedrockInvocationContext context = anthropicChatApi.chatCompletion(request); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + AnthropicChatResponse response = context.response(); + assertThat(response).isNotNull(); logger.info("" + response.content()); assertThat(response).isNotNull(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/metadata/BedrockAnthropic3ChatResponseMetadataT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/metadata/BedrockAnthropic3ChatResponseMetadataT.java new file mode 100644 index 00000000000..2aff44475af --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/anthropic3/metadata/BedrockAnthropic3ChatResponseMetadataT.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.anthropic3.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.anthropic3.BedrockAnthropic3ChatClient; +import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockAnthropic3ChatResponseMetadataT { + + @Autowired + private BedrockAnthropic3ChatClient client; + + @Test + public void chatCompletion() { + Prompt prompt = new Prompt("Can I fly like a bird?"); + + ChatResponse response = this.client.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata()).isInstanceOf(BedrockAnthropic3ChatResponseMetadata.class); + + BedrockAnthropic3ChatResponseMetadata metadata = (BedrockAnthropic3ChatResponseMetadata) response.getMetadata(); + assertThat(metadata.getId()).isNotNull(); + assertThat(metadata.getInvocationLatency()).isNotNull(); + + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public Anthropic3ChatBedrockApi anthropicApi() { + return new Anthropic3ChatBedrockApi(Anthropic3ChatBedrockApi.AnthropicChatModel.CLAUDE_V3_SONNET.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(5)); + } + + @Bean + public BedrockAnthropic3ChatClient anthropicChatClient(Anthropic3ChatBedrockApi anthropicApi) { + return new BedrockAnthropic3ChatClient(anthropicApi); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java index 540a6bd2bf7..517c141d56d 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereChatBedrockApiIT.java @@ -23,6 +23,8 @@ import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest; import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest.Truncate; @@ -83,7 +85,14 @@ public void chatCompletion() { .withTruncate(Truncate.NONE) .build(); - CohereChatResponse response = cohereChatApi.chatCompletion(request); + AmazonBedrockInvocationContext context = cohereChatApi.chatCompletion(request); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + CohereChatResponse response = context.response(); + assertThat(response).isNotNull(); assertThat(response).isNotNull(); assertThat(response.prompt()).isEqualTo(request.prompt()); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java index f96269fed11..d6d6e05938d 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/api/CohereEmbeddingBedrockApiIT.java @@ -24,6 +24,8 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingModel; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest; import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingResponse; @@ -48,7 +50,14 @@ public void embedText() { List.of("I like to eat apples", "I like to eat oranges"), CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT, CohereEmbeddingRequest.Truncate.NONE); - CohereEmbeddingResponse response = api.embedding(request); + AmazonBedrockInvocationContext context = api.embedding(request); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + CohereEmbeddingResponse response = context.response(); + assertThat(response).isNotNull(); assertThat(response).isNotNull(); assertThat(response.texts()).isEqualTo(request.texts()); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/metadata/BedrockCohereChatResponseMetadataT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/metadata/BedrockCohereChatResponseMetadataT.java new file mode 100644 index 00000000000..34faf7a2a97 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/cohere/metadata/BedrockCohereChatResponseMetadataT.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.cohere.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.cohere.BedrockCohereChatClient; +import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi; +import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatModel; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockCohereChatResponseMetadataT { + + @Autowired + private BedrockCohereChatClient client; + + @Test + public void chatCompletion() { + Prompt prompt = new Prompt("Can I fly like a bird?"); + + ChatResponse response = this.client.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata()).isInstanceOf(CohereChatResponseMetadata.class); + + CohereChatResponseMetadata metadata = (CohereChatResponseMetadata) response.getMetadata(); + assertThat(metadata.getId()).isNotNull(); + assertThat(metadata.getInvocationLatency()).isNotNull(); + + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public CohereChatBedrockApi cohereApi() { + return new CohereChatBedrockApi(CohereChatModel.COHERE_COMMAND_V14.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockCohereChatClient cohereChatClient(CohereChatBedrockApi cohereApi) { + return new BedrockCohereChatClient(cohereApi); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java index 8525471d144..390d2ff1dd3 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/api/Ai21Jurassic2ChatBedrockApiIT.java @@ -22,6 +22,8 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest; import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse; @@ -46,7 +48,14 @@ public void chatCompletion() { new Ai21Jurassic2ChatRequest.FloatScalePenalty(0.5f, true, true, true, true, true), new Ai21Jurassic2ChatRequest.IntegerScalePenalty(1, true, true, true, true, true)); - Ai21Jurassic2ChatResponse response = api.chatCompletion(request); + AmazonBedrockInvocationContext context = api.chatCompletion(request); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + Ai21Jurassic2ChatResponse response = context.response(); + assertThat(response).isNotNull(); assertThat(response).isNotNull(); assertThat(response.completions()).isNotEmpty(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/metadata/BedrockAi21Jurassic2ChatResponseMetadataT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/metadata/BedrockAi21Jurassic2ChatResponseMetadataT.java new file mode 100644 index 00000000000..b76bfe36063 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/jurassic2/metadata/BedrockAi21Jurassic2ChatResponseMetadataT.java @@ -0,0 +1,94 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.jurassic2.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatClient; +import org.springframework.ai.bedrock.jurassic2.BedrockAi21Jurassic2ChatOptions; +import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockAi21Jurassic2ChatResponseMetadataT { + + @Autowired + private BedrockAi21Jurassic2ChatClient client; + + @Test + public void chatCompletion() { + Prompt prompt = new Prompt("Can I fly like a bird?"); + + ChatResponse response = this.client.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata()).isInstanceOf(BedrockAi21Jurassic2ChatResponseMetadata.class); + + BedrockAi21Jurassic2ChatResponseMetadata metadata = (BedrockAi21Jurassic2ChatResponseMetadata) response + .getMetadata(); + assertThat(metadata.getId()).isNotNull(); + assertThat(metadata.getInvocationLatency()).isNotNull(); + + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public Ai21Jurassic2ChatBedrockApi jurassic2ChatBedrockApi() { + return new Ai21Jurassic2ChatBedrockApi( + Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatModel.AI21_J2_MID_V1.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockAi21Jurassic2ChatClient bedrockAi21Jurassic2ChatClient( + Ai21Jurassic2ChatBedrockApi jurassic2ChatBedrockApi) { + return new BedrockAi21Jurassic2ChatClient(jurassic2ChatBedrockApi, + BedrockAi21Jurassic2ChatOptions.builder() + .withTemperature(0.5f) + .withMaxTokens(100) + .withTopP(0.9f) + .build()); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java index dc97d8e7b8f..ed299cf4f3f 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/api/Llama2ChatBedrockApiIT.java @@ -23,6 +23,8 @@ import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatRequest; import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatResponse; @@ -31,6 +33,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") @@ -48,7 +51,14 @@ public void chatCompletion() { .withMaxGenLen(20) .build(); - Llama2ChatResponse response = llama2ChatApi.chatCompletion(request); + AmazonBedrockInvocationContext context = llama2ChatApi.chatCompletion(request); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + Llama2ChatResponse response = context.response(); + assertThat(response).isNotNull(); System.out.println(response.generation()); assertThat(response).isNotNull(); diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/metadata/Llama2ChatResponseMetadataT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/metadata/Llama2ChatResponseMetadataT.java new file mode 100644 index 00000000000..aee07a59838 --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/llama2/metadata/Llama2ChatResponseMetadataT.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.llama2.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatClient; +import org.springframework.ai.bedrock.llama2.BedrockLlama2ChatOptions; +import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi; +import org.springframework.ai.bedrock.llama2.api.Llama2ChatBedrockApi.Llama2ChatModel; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class Llama2ChatResponseMetadataT { + + @Autowired + private BedrockLlama2ChatClient client; + + @Test + public void chatCompletion() { + Prompt prompt = new Prompt("Can I fly like a bird?"); + + ChatResponse response = this.client.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata()).isInstanceOf(BedrockLlama2ChatResponseMetadata.class); + + BedrockLlama2ChatResponseMetadata metadata = (BedrockLlama2ChatResponseMetadata) response.getMetadata(); + assertThat(metadata.getId()).isNotNull(); + assertThat(metadata.getInvocationLatency()).isNotNull(); + + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public Llama2ChatBedrockApi llama2Api() { + return new Llama2ChatBedrockApi(Llama2ChatModel.LLAMA2_13B_CHAT_V1.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockLlama2ChatClient llama2ChatClient(Llama2ChatBedrockApi llama2Api) { + return new BedrockLlama2ChatClient(llama2Api, + BedrockLlama2ChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build()); + } + + } + +} diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java index e7bb1f8bff3..85c5c837848 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanChatBedrockApiIT.java @@ -24,6 +24,8 @@ import reactor.core.publisher.Flux; import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest; import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse; @@ -33,6 +35,7 @@ /** * @author Christian Tzolov + * @author Wei Jiang */ @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") @@ -50,7 +53,15 @@ public class TitanChatBedrockApiIT { @Test public void chatCompletion() { - TitanChatResponse response = titanBedrockApi.chatCompletion(titanChatRequest); + AmazonBedrockInvocationContext context = titanBedrockApi.chatCompletion(titanChatRequest); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + TitanChatResponse response = context.response(); + assertThat(response).isNotNull(); + assertThat(response.results()).hasSize(1); assertThat(response.results().get(0).outputText()).contains("Blackbeard"); } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java index a666793e0c7..37319e30cf6 100644 --- a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/api/TitanEmbeddingBedrockApiIT.java @@ -21,32 +21,46 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + import software.amazon.awssdk.regions.Region; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext; +import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest; import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; import org.springframework.core.io.DefaultResourceLoader; import static org.assertj.core.api.Assertions.assertThat; /** * @author Christian Tzolov + * @author Wei Jiang */ +@SpringBootTest @EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") @EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") public class TitanEmbeddingBedrockApiIT { + @Autowired + private TitanEmbeddingBedrockApi api; + @Test public void embedText() { - TitanEmbeddingBedrockApi titanEmbedApi = new TitanEmbeddingBedrockApi( - TitanEmbeddingModel.TITAN_EMBED_TEXT_V1.id(), Region.US_EAST_1.id(), Duration.ofMinutes(2)); - TitanEmbeddingRequest request = TitanEmbeddingRequest.builder().withInputText("I like to eat apples.").build(); - TitanEmbeddingResponse response = titanEmbedApi.embedding(request); + AmazonBedrockInvocationContext context = api.embedding(request); + assertThat(context).isNotNull(); + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + TitanEmbeddingResponse response = context.response(); assertThat(response).isNotNull(); assertThat(response.inputTextTokenCount()).isEqualTo(6); assertThat(response.embedding()).hasSize(1536); @@ -55,9 +69,6 @@ public void embedText() { @Test public void embedImage() throws IOException { - TitanEmbeddingBedrockApi titanEmbedApi = new TitanEmbeddingBedrockApi( - TitanEmbeddingModel.TITAN_EMBED_IMAGE_V1.id(), Region.US_EAST_1.id(), Duration.ofMinutes(2)); - byte[] image = new DefaultResourceLoader().getResource("classpath:/spring_framework.png") .getContentAsByteArray(); @@ -66,11 +77,29 @@ public void embedImage() throws IOException { TitanEmbeddingRequest request = TitanEmbeddingRequest.builder().withInputImage(imageBase64).build(); - TitanEmbeddingResponse response = titanEmbedApi.embedding(request); + AmazonBedrockInvocationContext context = api.embedding(request); + assertThat(context).isNotNull(); + + AmazonBedrockInvocationMetadata metadata = context.metadata(); + assertThat(metadata).isNotNull(); + + TitanEmbeddingResponse response = context.response(); + assertThat(response).isNotNull(); assertThat(response).isNotNull(); assertThat(response.inputTextTokenCount()).isEqualTo(0); // e.g. image input assertThat(response.embedding()).hasSize(1024); } + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public TitanEmbeddingBedrockApi cohereEmbeddingApi() { + return new TitanEmbeddingBedrockApi(TitanEmbeddingModel.TITAN_EMBED_TEXT_V1.id(), Region.US_EAST_1.id(), + Duration.ofMinutes(2)); + } + + } + } diff --git a/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/metadata/BedrockTitanChatResponseMetadataT.java b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/metadata/BedrockTitanChatResponseMetadataT.java new file mode 100644 index 00000000000..80b85b939ac --- /dev/null +++ b/models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/metadata/BedrockTitanChatResponseMetadataT.java @@ -0,0 +1,86 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.bedrock.titan.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.bedrock.titan.BedrockTitanChatClient; +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi; +import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatModel; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; +import software.amazon.awssdk.regions.Region; + +/** + * @author Wei Jiang + */ +@SpringBootTest +@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*") +@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*") +public class BedrockTitanChatResponseMetadataT { + + @Autowired + private BedrockTitanChatClient client; + + @Test + public void chatCompletion() { + Prompt prompt = new Prompt("Can I fly like a bird?"); + + ChatResponse response = this.client.call(prompt); + + assertThat(response).isNotNull(); + assertThat(response.getMetadata()).isNotNull(); + assertThat(response.getMetadata()).isInstanceOf(BedrockTitanChatResponseMetadata.class); + + BedrockTitanChatResponseMetadata metadata = (BedrockTitanChatResponseMetadata) response.getMetadata(); + assertThat(metadata.getId()).isNotNull(); + assertThat(metadata.getInvocationLatency()).isNotNull(); + + assertThat(response.getMetadata().getUsage()).isNotNull(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0); + } + + @SpringBootConfiguration + public static class TestConfiguration { + + @Bean + public TitanChatBedrockApi titanApi() { + return new TitanChatBedrockApi(TitanChatModel.TITAN_TEXT_EXPRESS_V1.id(), + EnvironmentVariableCredentialsProvider.create(), Region.US_EAST_1.id(), new ObjectMapper(), + Duration.ofMinutes(2)); + } + + @Bean + public BedrockTitanChatClient titanChatClient(TitanChatBedrockApi titanApi) { + return new BedrockTitanChatClient(titanApi); + } + + } + +}