diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java index ee5ff3e5..6ad331f1 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java @@ -1,26 +1,21 @@ package com.sap.ai.sdk.orchestration; +import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.fromValue; + import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.Data; import lombok.RequiredArgsConstructor; import lombok.experimental.Accessors; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; - -import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.fromValue; - @Data @Accessors(fluent = true) public class AzureContentFilter implements ContentFilter { - @Nullable - private Setting hate; - @Nullable - private Setting selfHarm; - @Nullable - private Setting sexual; - @Nullable - private Setting violence; + @Nullable private Setting hate; + @Nullable private Setting selfHarm; + @Nullable private Setting sexual; + @Nullable private Setting violence; @RequiredArgsConstructor public enum Setting { diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfig.java index 8d8982b2..9651553d 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfig.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfig.java @@ -3,14 +3,15 @@ import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; import io.vavr.control.Option; - import javax.annotation.Nonnull; import 
lombok.AccessLevel; import lombok.Data; +import lombok.RequiredArgsConstructor; import lombok.Setter; @Data @Setter(AccessLevel.PRIVATE) +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) public class DefaultOrchestrationConfig> implements OrchestrationConfig { @@ -20,45 +21,98 @@ public class DefaultOrchestrationConfig> @Nonnull private Option inputContentFilter = Option.none(); @Nonnull private Option outputContentFilter = Option.none(); - @Nonnull - @Override + @Nonnull private final T wrapper; + @SuppressWarnings("unchecked") - public T instance() { - return (T) this; + private DefaultOrchestrationConfig() { + wrapper = (T) this; + } + + /** + * Create a new instance of {@link DefaultOrchestrationConfig} to delegate to. This is useful when + * exposing the {@link OrchestrationConfig} in other objects, without re-implementing it. To + * maintain fluent API usage, the given wrapper object will be returned by the fluent methods, + * instead of this instance. + * + * @param wrapper The wrapper that delegates to this object. + * @param The type of the wrapper object. + * @return The new instance. + * @see #standalone() + */ + @Nonnull + public static > DefaultOrchestrationConfig asDelegateFor( + @Nonnull final T wrapper) { + return new DefaultOrchestrationConfig<>(wrapper); + } + + /** + * Create an implementation without any object delegating to it. The fluent API will return this + * object itself. + * + * @return The new instance. 
+ * @see #asDelegateFor(OrchestrationConfig) + */ + @Nonnull + public static DefaultOrchestrationConfig standalone() { + return new DefaultOrchestrationConfig<>(); } @Nonnull @Override public T withLlmConfig(@Nonnull final LLMModuleConfig llm) { this.llmConfig = Option.some(llm); - return instance(); + return wrapper; } @Nonnull @Override public T withTemplate(@Nonnull final TemplatingModuleConfig template) { this.template = Option.some(template); - return instance(); + return wrapper; } @Nonnull @Override public T withMaskingConfig(@Nonnull final MaskingConfig maskingConfig) { this.maskingConfig = Option.some(maskingConfig); - return instance(); + return wrapper; } @Nonnull @Override public T withInputContentFilter(@Nonnull final ContentFilter filter) { this.inputContentFilter = Option.some(filter); - return instance(); + return wrapper; } @Nonnull @Override public T withOutputContentFilter(@Nonnull final ContentFilter filter) { this.outputContentFilter = Option.some(filter); - return instance(); + return wrapper; + } + + /** + * Copy the configuration into the given target configuration. The copy is + * shallow and does not override any existing configuration. + * + *

<p>This has two main use cases: + * + * <ol> + * <li>Duplicating a config + * <li>Applying defaults to a config + * </ol>
+ * + * @param source The source configuration to copy from. + */ + public DefaultOrchestrationConfig copyFrom(@Nonnull final OrchestrationConfig source) { + llmConfig.orElse(source::getLlmConfig).forEach(this::withLlmConfig); + template.orElse(source::getTemplate).forEach(this::withTemplate); + maskingConfig.orElse(source::getMaskingConfig).forEach(this::withMaskingConfig); + inputContentFilter.orElse(source::getInputContentFilter).forEach(this::withInputContentFilter); + outputContentFilter + .orElse(source::getOutputContentFilter) + .forEach(this::withOutputContentFilter); + return this; } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DpiMaskingConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DpiMaskingConfig.java index 3b655fac..0485b40f 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DpiMaskingConfig.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DpiMaskingConfig.java @@ -6,7 +6,6 @@ import com.sap.ai.sdk.orchestration.client.model.DPIEntities; import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig; -import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig; import com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig; import java.util.List; import javax.annotation.Nonnull; @@ -33,11 +32,10 @@ public static Builder forPseudonymization() { @Nonnull MaskingProviderConfig toMaskingProviderDTO() { var entities = this.entities.stream().map(it -> DPIEntityConfig.create().type(it)).toList(); - return - MaskingProviderConfig.create() - .type(SAP_DATA_PRIVACY_INTEGRATION) - .method(maskingMethod) - .entities(entities); + return MaskingProviderConfig.create() + .type(SAP_DATA_PRIVACY_INTEGRATION) + .method(maskingMethod) + .entities(entities); } @RequiredArgsConstructor(access = AccessLevel.PRIVATE) diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ModuleConfigFactory.java 
b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ModuleConfigFactory.java index a7231154..561f9461 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ModuleConfigFactory.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ModuleConfigFactory.java @@ -8,19 +8,18 @@ import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs; import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; import io.vavr.control.Option; - -import javax.annotation.Nonnull; import java.util.ArrayList; import java.util.List; +import javax.annotation.Nonnull; class ModuleConfigFactory { @Nonnull static ModuleConfigs toModuleConfigDTO( - @Nonnull final OrchestrationConfig config, @Nonnull final List messages) { + @Nonnull final OrchestrationConfig config, @Nonnull final List messages) { LLMModuleConfig llm = - config - .getLlmConfig() - .getOrElseThrow(() -> new IllegalStateException("LLM module config is required")); + config + .getLlmConfig() + .getOrElseThrow(() -> new IllegalStateException("LLM module config is required")); /* * Currently, we have to merge the prompt into the template configuration. @@ -35,21 +34,21 @@ static ModuleConfigs toModuleConfigDTO( if (messagesWithPrompt.isEmpty()) { throw new IllegalStateException( - "A prompt is required. Pass at least one message or configure the templating module."); + "A prompt is required. 
Pass at least one message or configure the templating module."); } TemplatingModuleConfig template = TemplatingModuleConfig.create().template(messagesWithPrompt); maybeTemplate.map(TemplatingModuleConfig::getDefaults).forEach(template::defaults); ModuleConfigs dto = - ModuleConfigs.create().llmModuleConfig(llm).templatingModuleConfig(template); + ModuleConfigs.create().llmModuleConfig(llm).templatingModuleConfig(template); config - .getMaskingConfig() - .filter(DpiMaskingConfig.class::isInstance) - .map(DpiMaskingConfig.class::cast) - .map(DpiMaskingConfig::toMaskingProviderDTO) - .map(it -> MaskingModuleConfig.create().maskingProviders(it)) - .forEach(dto::maskingModuleConfig); + .getMaskingConfig() + .filter(DpiMaskingConfig.class::isInstance) + .map(DpiMaskingConfig.class::cast) + .map(DpiMaskingConfig::toMaskingProviderDTO) + .map(it -> MaskingModuleConfig.create().maskingProviders(it)) + .forEach(dto::maskingModuleConfig); var maybeInputFilter = config.getInputContentFilter(); var maybeOutputFilter = config.getOutputContentFilter(); @@ -57,15 +56,17 @@ static ModuleConfigs toModuleConfigDTO( if (maybeInputFilter.isDefined() || maybeOutputFilter.isDefined()) { var filter = FilteringModuleConfig.create(); maybeInputFilter - .filter(AzureContentFilter.class::isInstance) - .map(AzureContentFilter.class::cast) - .map(AzureContentFilter::toFilterConfigDTO) - .map(it -> FilteringConfig.create().filters(it)).forEach(filter::input); + .filter(AzureContentFilter.class::isInstance) + .map(AzureContentFilter.class::cast) + .map(AzureContentFilter::toFilterConfigDTO) + .map(it -> FilteringConfig.create().filters(it)) + .forEach(filter::input); maybeOutputFilter - .filter(AzureContentFilter.class::isInstance) - .map(AzureContentFilter.class::cast) - .map(AzureContentFilter::toFilterConfigDTO) - .map(it -> FilteringConfig.create().filters(it)).forEach(filter::output); + .filter(AzureContentFilter.class::isInstance) + .map(AzureContentFilter.class::cast) + 
.map(AzureContentFilter::toFilterConfigDTO) + .map(it -> FilteringConfig.create().filters(it)) + .forEach(filter::output); dto = dto.filteringModuleConfig(filter); } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index c09a1b45..4b850ac1 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -38,7 +38,7 @@ public class OrchestrationClient implements OrchestrationConfig clientConfig = - new DefaultOrchestrationConfig<>(); + DefaultOrchestrationConfig.asDelegateFor(this); @Nonnull private final HttpDestination destination; @@ -47,12 +47,6 @@ public OrchestrationClient() { this.destination = Core.getDestinationForDeployment("db1d64d9f06be467", "default").asHttp(); } - @Nonnull - @Override - public OrchestrationClient instance() { - return this; - } - /** * Generate a completion for the given user prompt. 
* @@ -61,11 +55,15 @@ public OrchestrationClient instance() { * @throws OrchestrationClientException if the request fails */ @Nonnull - // TODO: decide if we want to offer this in addition to the already simple `new Prompt("asdf")` public String chatCompletion(@Nonnull final String userPrompt) throws OrchestrationClientException { var response = chatCompletion(new OrchestrationPrompt(userPrompt)); - return response.getOrchestrationResult().getChoices().get(0).getMessage().getContent(); + + var choice = response.getOrchestrationResult().getChoices().get(0); + if (choice.getFinishReason().equalsIgnoreCase("content_filter")) { + throw new OrchestrationClientException("Output content filter triggered"); + } + return choice.getMessage().getContent(); } /** @@ -78,14 +76,7 @@ public String chatCompletion(@Nonnull final String userPrompt) @Nonnull public CompletionPostResponse chatCompletion(@Nonnull final OrchestrationPrompt prompt) throws OrchestrationClientException { - var moduleConfigsDto = prompt.toModuleConfigDTO(clientConfig); - var dto = - CompletionPostRequest.create() - .orchestrationConfig( - com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig.create() - .moduleConfigurations(moduleConfigsDto)) - .messagesHistory(prompt.getMessages()) - .inputParams(prompt.getTemplateParameters()); + var dto = prompt.toCompletionPostRequestDTO(clientConfig); return executeRequest(dto); } @@ -106,6 +97,7 @@ public Stream streamChatCompletionDelta(@Nonnull final OrchestrationProm @Nonnull protected CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request) { final var client = ApacheHttpClient5Accessor.getHttpClient(destination); + // TODO: update after AiCoreService refactoring final BasicClassicHttpRequest postRequest = new HttpPost("/v2/inference/deployments/db1d64d9f06be467/completion"); try { diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationConfig.java 
b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationConfig.java index 07760b89..4064b2c2 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationConfig.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationConfig.java @@ -12,8 +12,6 @@ * @param Type of the specific implementation to make a fluent API possible. */ public interface OrchestrationConfig> { - @Nonnull - T instance(); @Nonnull Option getLlmConfig(); @@ -44,30 +42,4 @@ public interface OrchestrationConfig> { @Nonnull T withOutputContentFilter(@Nonnull final ContentFilter filter); - - /** - * Copy the configuration into the given target configuration. The copy is - * shallow and does not override any existing configuration. - * - *

<p>This has two main use cases: - * - * <ol> - * <li>Duplicating a config - * <li>Applying defaults to a config - * </ol>
- * - * @param source The source configuration to copy from. - */ - default T copyOrchestrationConfigurationFrom(@Nonnull final OrchestrationConfig source) { - getLlmConfig().orElse(source::getLlmConfig).forEach(this::withLlmConfig); - getTemplate().orElse(source::getTemplate).forEach(this::withTemplate); - getMaskingConfig().orElse(source::getMaskingConfig).forEach(this::withMaskingConfig); - getInputContentFilter() - .orElse(source::getInputContentFilter) - .forEach(this::withInputContentFilter); - getOutputContentFilter() - .orElse(source::getOutputContentFilter) - .forEach(this::withOutputContentFilter); - return instance(); - } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java index 88bc91b0..e0bfd2c9 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java @@ -1,7 +1,7 @@ package com.sap.ai.sdk.orchestration; import com.sap.ai.sdk.orchestration.client.model.ChatMessage; -import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; import java.util.List; import java.util.Map; import javax.annotation.Nonnull; @@ -20,13 +20,8 @@ public class OrchestrationPrompt implements OrchestrationConfig delegate = new DefaultOrchestrationConfig<>(); - - @Nonnull - @Override - public OrchestrationPrompt instance() { - return this; - } + DefaultOrchestrationConfig delegate = + DefaultOrchestrationConfig.asDelegateFor(this); public OrchestrationPrompt(@Nonnull final String message) { this(List.of(ChatMessage.create().role("user").content(message)), Map.of()); @@ -41,9 +36,15 @@ public OrchestrationPrompt(@Nonnull final Map inputParams) { } @Nonnull - ModuleConfigs toModuleConfigDTO(@Nonnull final OrchestrationConfig defaults) { - // duplicate the 
prompt config so it isn't modified, to make sure this prompt can be reused - var config = new DefaultOrchestrationConfig<>().copyOrchestrationConfigurationFrom(this).copyOrchestrationConfigurationFrom(defaults); - return ModuleConfigFactory.toModuleConfigDTO(config, messages); + CompletionPostRequest toCompletionPostRequestDTO(@Nonnull final OrchestrationConfig defaults) { + // duplicate the prompt config, then apply the defaults to the copy + // that way this prompt remains unchanged and can be reused + var config = DefaultOrchestrationConfig.standalone().copyFrom(this).copyFrom(defaults); + var moduleConfigDTO = ModuleConfigFactory.toModuleConfigDTO(config, messages); + return CompletionPostRequest.create() + .orchestrationConfig( + com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig.create() + .moduleConfigurations(moduleConfigDTO)) + .inputParams(templateParameters); } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java index 981513b4..59d3d97e 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java @@ -19,19 +19,13 @@ /** Configuration to be used for orchestration requests. 
*/ @Data -@Getter(AccessLevel.PACKAGE) -@Setter(AccessLevel.PRIVATE) +@Getter(AccessLevel.NONE) +@Setter(AccessLevel.NONE) public class OrchestrationChatOptions implements ChatOptions, OrchestrationConfig { @Delegate @Nonnull private final DefaultOrchestrationConfig delegate = - new DefaultOrchestrationConfig<>(); - - @Nonnull - @Override - public OrchestrationChatOptions instance() { - return this; - } + DefaultOrchestrationConfig.asDelegateFor(this); @Getter(AccessLevel.PUBLIC) @Nonnull @@ -114,9 +108,9 @@ public Double getTopP() { @Override public OrchestrationChatOptions copy() { var copy = new OrchestrationChatOptions(); - - // TODO: implement, needed for chat memory apparently - throw new RuntimeException("Not implemented"); + copy.delegate.copyFrom(this.delegate); + copy.templateParameters.putAll(this.templateParameters); + return copy; } @SuppressWarnings("unchecked") diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponse.java index 36da9147..1a58bd4b 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponse.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponse.java @@ -40,7 +40,7 @@ static OrchestrationChatResponse fromOrchestrationResponse( } @Nonnull - private static List toGenerations(@Nonnull final LLMModuleResult result) { + static List toGenerations(@Nonnull final LLMModuleResult result) { return result.getChoices().stream() .map(OrchestrationChatResponse::toAssistantMessage) .map(Generation::new) @@ -48,7 +48,7 @@ private static List toGenerations(@Nonnull final LLMModuleResult res } @Nonnull - private static AssistantMessage toAssistantMessage(@Nonnull final LLMChoice choice) { + static AssistantMessage toAssistantMessage(@Nonnull final LLMChoice choice) { Map metadata = new HashMap<>(); 
metadata.put("finish_reason", choice.getFinishReason()); metadata.put("index", choice.getIndex()); @@ -59,7 +59,7 @@ private static AssistantMessage toAssistantMessage(@Nonnull final LLMChoice choi } @Nonnull - private static ChatResponseMetadata toChatResponseMetadata( + static ChatResponseMetadata toChatResponseMetadata( @Nonnull final LLMModuleResult orchestrationResult) { var metadataBuilder = ChatResponseMetadata.builder(); diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java index 74dc7ce9..21646975 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java @@ -1,44 +1,35 @@ package com.sap.ai.sdk.orchestration; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.VERY_STRICT; +import static com.sap.ai.sdk.orchestration.ModuleConfigFactory.toModuleConfigDTO; +import static com.sap.ai.sdk.orchestration.client.model.FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + import com.sap.ai.sdk.orchestration.client.model.ChatMessage; import com.sap.ai.sdk.orchestration.client.model.DPIEntities; import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig; import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; import com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig; import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import java.util.List; - -import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.VERY_STRICT; -import static 
com.sap.ai.sdk.orchestration.ModuleConfigFactory.toModuleConfigDTO; -import static com.sap.ai.sdk.orchestration.client.model.FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.Mockito.mock; - class ModuleConfigFactoryTest { private static final List messages = List.of(mock(ChatMessage.class)); private DefaultOrchestrationConfig config; @BeforeEach void setUp() { - config = new DefaultOrchestrationConfig<>(); + config = DefaultOrchestrationConfig.standalone(); config.withLlmConfig(mock(LLMModuleConfig.class)); } - @Test - void testInstance() { - OrchestrationConfig instance = config.instance(); - - assertSame(config, instance); - } - @Test void testThrowsOnMissingConfig() { - config = new DefaultOrchestrationConfig<>(); + config = DefaultOrchestrationConfig.standalone(); assertThatThrownBy(() -> toModuleConfigDTO(config, messages)) .isInstanceOf(IllegalStateException.class) @@ -50,6 +41,16 @@ void testThrowsOnMissingConfig() { .hasMessageContaining("prompt is required"); } + @Test + void testLlmConfig() { + var llmConfig = mock(LLMModuleConfig.class); + + config.withLlmConfig(llmConfig); + + var result = toModuleConfigDTO(config, messages).getLlmModuleConfig(); + assertThat(result).isSameAs(llmConfig); + } + @Test void testTemplateIsCreatedFromMessages() { var result = toModuleConfigDTO(config, messages).getTemplatingModuleConfig(); diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationClientTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationClientTest.java new file mode 100644 index 00000000..f2267d74 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationClientTest.java @@ -0,0 +1,81 @@ +package com.sap.ai.sdk.orchestration; + +import static 
org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; + +/** + * Test that queries are on the right URL, with the right headers. Also check that the received + * response is parsed correctly in the generated client. 
+ */ +public class OrchestrationClientTest { + private OrchestrationClient client; + + private static final LLMModuleConfig LLM_CONFIG = + LLMModuleConfig.create().modelName("gpt-35-turbo-16k").modelParams(Map.of()); + + @BeforeEach + void setup() { + var destination = DefaultHttpDestination.builder("").build(); + client = spy(new OrchestrationClient(destination).withLlmConfig(LLM_CONFIG)); + } + + @Test + void testSimpleChatCompletion() { + stubResponse("stop"); + + var result = client.chatCompletion("Hello there!"); + assertThat(result).isEqualTo("General Kenobi!"); + + var expected = ChatMessage.create().role("user").content("Hello there!"); + verify(client).chatCompletion( + ArgumentMatchers.argThat( + prompt -> prompt.getMessages().contains(expected))); + } + + @Test + void testSimpleChatCompletionThrowsOnOutputContentFilter() { + + stubResponse("content_filter"); + + assertThatThrownBy(() -> client.chatCompletion("foo")) + .isInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("content filter"); + } + + private void stubResponse(String finishReason) { + var response = mock(CompletionPostResponse.class); + var orchestrationResult = mock(LLMModuleResult.class); + var llmChoice = + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("General Kenobi!")) + .finishReason(finishReason); + + when(orchestrationResult.getChoices()).thenReturn(List.of(llmChoice)); + when(response.getOrchestrationResult()).thenReturn(orchestrationResult); + + doReturn(response).when(client).executeRequest(any()); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandlerTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandlerTest.java new file mode 100644 index 00000000..f77b26de --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandlerTest.java @@ -0,0 +1,82 @@ +package 
com.sap.ai.sdk.orchestration; + +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.badRequest; +import static com.github.tomakehurst.wiremock.client.WireMock.ok; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.serverError; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Test that queries are on the right URL, with the right headers. Also check that the received + * response is parsed correctly in the generated client. 
+ */ +@SuppressWarnings("UnstableApiUsage") +@WireMockTest +public class OrchestrationResponseHandlerTest { + private OrchestrationClient client; + + private static final LLMModuleConfig LLM_CONFIG = + LLMModuleConfig.create().modelName("gpt-35-turbo-16k").modelParams(Map.of()); + + @BeforeEach + void setup(WireMockRuntimeInfo server) { + var destination = DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); + client = new OrchestrationClient(destination).withLlmConfig(LLM_CONFIG); + ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); + } + + @AfterEach + void reset() { + ApacheHttpClient5Accessor.setHttpClientCache(null); + ApacheHttpClient5Accessor.setHttpClientFactory(null); + } + + @Test + void testSuccessResponse() { + var response = + ok().withBodyFile("serializedResponse.json").withHeader("Content-Type", "application/json"); + stubFor(post(anyUrl()).willReturn(response)); + + var result = client.chatCompletion("Hello there!"); + + assertThat(result).isEqualTo("General Kenobi!"); + } + + @Test + void testGenericErrorHandling() { + stubFor(post(anyUrl()).willReturn(serverError())); + + assertThatThrownBy(() -> client.chatCompletion("Hello World!")) + .isInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("500 Server Error"); + } + + @Test + void testOrchestrationErrorParsing() { + stubFor( + post(anyUrl()) + .willReturn( + badRequest() + .withHeader("Content-Type", "application/json") + .withBodyFile("errorResponse.json"))); + + assertThatThrownBy(() -> client.chatCompletion("Hello World!")) + .isInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("400 Bad Request") + .hasMessageContaining("'orchestration_config' is a required property"); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/SerializationTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/SerializationTest.java new file mode 100644 index 00000000..65c22ced --- /dev/null +++ 
b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/SerializationTest.java @@ -0,0 +1,138 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.LENIENT; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.MODERATE; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.STRICT; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.VERY_STRICT; +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.databind.JsonNode; +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.DPIEntities; +import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult; +import com.sap.ai.sdk.orchestration.client.model.ModuleResults; +import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.TokenUsage; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; + +class SerializationTest { + + @Test + void testSerialization() throws IOException { + var llm = + LLMModuleConfig.create() + .modelName("gpt-35-turbo-16k") + .modelParams(Map.of("temperature", 0.5, "frequency_penalty", 1)); + var template = + TemplatingModuleConfig.create() + .template(List.of(ChatMessage.create().role("user").content("{{?input}}"))) + .defaults(Map.of("input", "Hello World!")); + + var inputFilter = new AzureContentFilter().selfHarm(LENIENT); + var outputFilter = + new AzureContentFilter() + .hate(VERY_STRICT) + .selfHarm(STRICT) + .sexual(MODERATE) + .violence(LENIENT); + + var masking = + 
DpiMaskingConfig.forAnonymization() + .withEntities(DPIEntities.ADDRESS, DPIEntities.IBAN, DPIEntities.LOCATION); + var inputParams = Map.of("input", "Reply with 'Orchestration Service is working!' in German"); + + var dto = + new OrchestrationPrompt(inputParams) + .withLlmConfig(llm) + .withTemplate(template) + .withInputContentFilter(inputFilter) + .withOutputContentFilter(outputFilter) + .withMaskingConfig(masking) + .toCompletionPostRequestDTO(DefaultOrchestrationConfig.standalone()); + + var actual = OrchestrationClient.JACKSON.valueToTree(dto); + + var expected = + OrchestrationClient.JACKSON.readValue( + getClass().getClassLoader().getResource("serializedRequest.json"), JsonNode.class); + + assertThat(actual) + .withRepresentation(it -> ((JsonNode) it).toPrettyString()) + .isEqualTo(expected); + } + + @Test + void testDeserialization() throws IOException { + + var llmResult = + LLMModuleResult.create() + .id("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj") + ._object("chat.completion") + .created(1721224505) + .model("gpt-35-turbo-16k") + .choices( + List.of( + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("General Kenobi!")) + .finishReason("stop"))) + .usage(TokenUsage.create().completionTokens(7).promptTokens(19).totalTokens(26)); + + var orchestrationResult = + LLMModuleResult.create() + .id("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj") + ._object("chat.completion") + .created(1721224505) + .model("gpt-35-turbo-16k") + .choices( + List.of( + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("General Kenobi!")) + .finishReason("stop"))) + .usage(TokenUsage.create().completionTokens(7).promptTokens(19).totalTokens(26)); + + var inputFilterResult = + GenericModuleResult.create() + .message("Input filter passed successfully.") + .data( + Map.of( + "original_service_response", + Map.of("Hate", 0, "SelfHarm", 0, "Sexual", 0, "Violence", 2), + "checked_text", "Hello there!")); + var 
outputFilterResult = + GenericModuleResult.create() + .message("Output filter passed successfully.") + .data( + Map.of( + "original_service_response", + Map.of("Hate", 0, "SelfHarm", 0, "Sexual", 0, "Violence", 2), + "checked_text", "General Kenobi!")); + var expected = + CompletionPostResponse.create() + .requestId("26ea36b5-c196-4806-a9a6-a686f0c6ad91") + .moduleResults( + ModuleResults.create() + .templating(List.of(ChatMessage.create().role("user").content("Hello there!"))) + .llm(llmResult) + .inputFiltering(inputFilterResult) + .outputFiltering(outputFilterResult) + .inputMasking(null) + .outputUnmasking(List.of())) + .orchestrationResult(orchestrationResult); + + var actual = + OrchestrationClient.JACKSON.readValue( + getClass().getClassLoader().getResourceAsStream("serializedResponse.json"), + CompletionPostResponse.class); + + assertThat(actual).isEqualTo(expected); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java deleted file mode 100644 index abc7b37c..00000000 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java +++ /dev/null @@ -1,193 +0,0 @@ -package com.sap.ai.sdk.orchestration.client; - -import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; -import static com.github.tomakehurst.wiremock.client.WireMock.jsonResponse; -import static com.github.tomakehurst.wiremock.client.WireMock.okJson; -import static com.github.tomakehurst.wiremock.client.WireMock.post; -import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; -import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; -import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static com.github.tomakehurst.wiremock.client.WireMock.verify; -import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.LENIENT; 
-import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.MODERATE; -import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.STRICT; -import static com.sap.ai.sdk.orchestration.AzureContentFilter.Setting.VERY_STRICT; -import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; -import com.github.tomakehurst.wiremock.junit5.WireMockTest; -import com.sap.ai.sdk.orchestration.AzureContentFilter; -import com.sap.ai.sdk.orchestration.OrchestrationClient; -import com.sap.ai.sdk.orchestration.OrchestrationClientException; -import com.sap.ai.sdk.orchestration.OrchestrationPrompt; -import com.sap.ai.sdk.orchestration.client.model.ChatMessage; -import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; -import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; -import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -/** - * Test that queries are on the right URL, with the right headers. Also check that the received - * response is parsed correctly in the generated client. 
- */ -@WireMockTest -public class OrchestrationUnitTest { - private OrchestrationClient client; - - private static final LLMModuleConfig LLM_CONFIG = - LLMModuleConfig.create() - .modelName("gpt-35-turbo-16k") - .modelParams( - Map.of( - "max_tokens", 50, - "temperature", 0.1, - "frequency_penalty", 0, - "presence_penalty", 0)); - - private static final TemplatingModuleConfig TEMPLATE_CONFIG = - TemplatingModuleConfig.create() - .template(List.of(ChatMessage.create().role("user").content("{{?input}}"))); - - @BeforeEach - void setup(WireMockRuntimeInfo server) { - final DefaultHttpDestination destination = - DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); - client = new OrchestrationClient(destination).withLlmConfig(LLM_CONFIG); - } - - @Test - void testTemplating() throws IOException { - final String response = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("templatingResponse.json") - .readAllBytes()); - stubFor(post(urlPathEqualTo("/completion")).willReturn(okJson(response))); - - final var inputParams = - Map.of("input", "Reply with 'Orchestration Service is working!' in German"); - - var msg = ChatMessage.create().role("user").content("{{?input}}"); - var prompt = - new OrchestrationPrompt(inputParams) - .withTemplate(TemplatingModuleConfig.create().template(List.of(msg))); - - final var result = client.chatCompletion(prompt); - - assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); - assertThat(result.getModuleResults().getTemplating().get(0).getContent()) - .isEqualTo("Reply with 'Orchestration Service is working!' 
in German"); - assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); - var llm = result.getModuleResults().getLlm(); - assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); - assertThat(llm.getObject()).isEqualTo("chat.completion"); - assertThat(llm.getCreated()).isEqualTo(1721224505); - assertThat(llm.getModel()).isEqualTo("gpt-35-turbo-16k"); - var choices = llm.getChoices(); - assertThat(choices.get(0).getIndex()).isEqualTo(0); - assertThat(choices.get(0).getMessage().getContent()) - .isEqualTo("Orchestration Service funktioniert!"); - assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); - assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - var usage = llm.getUsage(); - assertThat(usage.getCompletionTokens()).isEqualTo(7); - assertThat(usage.getPromptTokens()).isEqualTo(19); - assertThat(usage.getTotalTokens()).isEqualTo(26); - assertThat(result.getOrchestrationResult().getId()) - .isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); - assertThat(result.getOrchestrationResult().getObject()).isEqualTo("chat.completion"); - assertThat(result.getOrchestrationResult().getCreated()).isEqualTo(1721224505); - assertThat(result.getOrchestrationResult().getModel()).isEqualTo("gpt-35-turbo-16k"); - choices = result.getOrchestrationResult().getChoices(); - assertThat(choices.get(0).getIndex()).isEqualTo(0); - assertThat(choices.get(0).getMessage().getContent()) - .isEqualTo("Orchestration Service funktioniert!"); - assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); - assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - usage = result.getOrchestrationResult().getUsage(); - assertThat(usage.getCompletionTokens()).isEqualTo(7); - assertThat(usage.getPromptTokens()).isEqualTo(19); - assertThat(usage.getTotalTokens()).isEqualTo(26); - - // verify that null fields are absent from the sent request - final String request = - new String( - getClass() - 
.getClassLoader() - .getResourceAsStream("templatingRequest.json") - .readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); - } - - @Test - void testTemplatingBadRequest() { - stubFor( - post(urlPathEqualTo("/completion")) - .willReturn( - jsonResponse( - """ - { - "request_id": "51043a32-01f5-429a-b0e7-3a99432e43a4", - "code": 400, - "message": "Missing required parameters: ['input']", - "location": "Module: Templating", - "module_results": {} - } - """, - SC_BAD_REQUEST))); - - assertThatThrownBy(() -> client.chatCompletion(new OrchestrationPrompt("{{?input}}"))) - .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("400 Bad Request"); - } - - @Test - void testFilteringLoose() throws IOException { - final String response = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("filteringLooseResponse.json") - .readAllBytes()); - stubFor(post(urlPathEqualTo("/completion")).willReturn(okJson(response))); - - var config = new AzureContentFilter().hate(LENIENT).violence(MODERATE); - - var prompt = new OrchestrationPrompt("Hello World!").withOutputContentFilter(config); - - client.chatCompletion(prompt); - // the result is asserted in the verify step below - - // verify that null fields are absent from the sent request - final String request = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("filteringLooseRequest.json") - .readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); - } - - @Test - void testFilteringStrict() { - final String response = - """ - {"request_id": "bf6d6792-7adf-4d3c-9368-a73615af8c5a", "code": 400, "message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "location": "Input Filter", "module_results": {"templating": [{"role": "user", "content": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. 
Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}], "input_filtering": {"message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "data": {"original_service_response": {"Hate": 0, "SelfHarm": 0, "Sexual": 0, "Violence": 2}, "checked_text": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}}}}"""; - stubFor(post(urlPathEqualTo("/completion")).willReturn(jsonResponse(response, SC_BAD_REQUEST))); - - var config = new AzureContentFilter().hate(VERY_STRICT).violence(STRICT); - - var prompt = new OrchestrationPrompt("Hello World!").withInputContentFilter(config); - - assertThatThrownBy(() -> client.chatCompletion(prompt)) - .isInstanceOf(OrchestrationClientException.class) - .hasMessageContaining("400 Bad Request"); - } -} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptionsTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptionsTest.java new file mode 100644 index 00000000..051c797b --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptionsTest.java @@ -0,0 +1,35 @@ +package com.sap.ai.sdk.orchestration.spring; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class OrchestrationChatOptionsTest { + + OrchestrationChatOptions opts; + + @BeforeEach + void setUp() { + opts = new 
OrchestrationChatOptions(); + } + + @Test + void testFluentApi() { + assertThat(opts).isSameAs(opts.withLlmConfig(null)); + } + + @Test + void testHyperParameters() { + var llm = + LLMModuleConfig.create() + .modelName("foo") + .modelParams(Map.of("temperature", 0.5, "maxTokens", 100)); + opts.withLlmConfig(llm); + + assertThat(opts.getTemperature()).isEqualTo(0.5); + assertThat(opts.getMaxTokens()).isEqualTo(100); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java new file mode 100644 index 00000000..5a188ca5 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java @@ -0,0 +1,54 @@ +package com.sap.ai.sdk.orchestration.spring; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult; +import com.sap.ai.sdk.orchestration.client.model.TokenUsage; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; + +class OrchestrationChatResponseTest { + + @Test + void testToAssistantMessage() { + var choice = + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("Hello, world!")) + .finishReason("stop"); + + AssistantMessage message = OrchestrationChatResponse.toAssistantMessage(choice); + + assertThat(message.getContent()).isEqualTo("Hello, world!"); + assertThat(message.getMetadata()).containsEntry("finish_reason", "stop"); + assertThat(message.getMetadata()).containsEntry("index", 0); + } + + @Test + void testToChatResponseMetadata() { + var moduleResult = + LLMModuleResult.create() + .id("test-id") + ._object("test-object") + .created(123456789) + 
.model("test-model") + .choices(List.of()) + .usage(TokenUsage.create().completionTokens(20).promptTokens(10).totalTokens(30)); + + var metadata = OrchestrationChatResponse.toChatResponseMetadata(moduleResult); + + assertThat(metadata.getId()).isEqualTo("test-id"); + assertThat(metadata.getModel()).isEqualTo("test-model"); + assertThat(metadata.get("object")).isEqualTo("test-object"); + assertThat(metadata.get("created")).isEqualTo(123456789); + + var usage = metadata.getUsage(); + + assertThat(usage.getPromptTokens()).isEqualTo(10L); + assertThat(usage.getGenerationTokens()).isEqualTo(20L); + assertThat(usage.getTotalTokens()).isEqualTo(30L); + } +} diff --git a/orchestration/src/test/resources/errorResponse.json b/orchestration/src/test/resources/errorResponse.json new file mode 100644 index 00000000..6964cbc1 --- /dev/null +++ b/orchestration/src/test/resources/errorResponse.json @@ -0,0 +1,7 @@ +{ + "request_id": "59468e72-7309-4299-b988-bf3bbea461f8", + "code": 400, + "message": "'orchestration_config' is a required property", + "location": "request body", + "module_results": {} +} \ No newline at end of file diff --git a/orchestration/src/test/resources/filteringLooseRequest.json b/orchestration/src/test/resources/filteringLooseRequest.json deleted file mode 100644 index 7668190c..00000000 --- a/orchestration/src/test/resources/filteringLooseRequest.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "orchestration_config" : { - "module_configurations" : { - "llm_module_config" : { - "model_name" : "gpt-35-turbo-16k", - "model_params" : { - "temperature" : 0.1, - "max_tokens" : 50, - "frequency_penalty" : 0, - "presence_penalty" : 0 - }, - "model_version" : "latest" - }, - "templating_module_config" : { - "template" : [ { - "role" : "user", - "content" : "Hello World!" 
- } ] - }, - "filtering_module_config" : { - "output" : { - "filters" : [ { - "type" : "azure_content_safety", - "config" : { - "Hate" : 6, - "SelfHarm" : 4 - } - } ] - } - } - } - }, - "input_params" : { }, - "messages_history" : [ { - "role" : "user", - "content" : "Hello World!" - } ] -} \ No newline at end of file diff --git a/orchestration/src/test/resources/filteringLooseResponse.json b/orchestration/src/test/resources/filteringLooseResponse.json deleted file mode 100644 index ce2545f0..00000000 --- a/orchestration/src/test/resources/filteringLooseResponse.json +++ /dev/null @@ -1,86 +0,0 @@ -{ - "request_id": "b329745f-4b6b-4d42-b891-974b33689a19", - "module_results": { - "grounding": null, - "templating": [ - { - "role": "user", - "content": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent." - } - ], - "input_masking": null, - "input_filtering": { - "message": "Input filter passed successfully.", - "data": { - "original_service_response": { - "Hate": 0, - "SelfHarm": 0, - "Sexual": 0, - "Violence": 2 - }, - "checked_text": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent." - } - }, - "llm": { - "object": "chat.completion", - "id": "chatcmpl-9o4df7DpIjY6CJdfe9hws1lrWZbHq", - "created": 1721721259, - "model": "gpt-35-turbo", - "system_fingerprint": null, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Cozy Downtown Apartment for Sublet!\n\nLooking for a temporary place to call home in the heart of downtown? 
Look no further! This cozy apartment is up for subletting and offers a convenient and vibrant city living experience.\n\nFeatures:\n- Prime location" - }, - "logprobs": { - }, - "finish_reason": "length" - } - ], - "usage": { - "completion_tokens": 50, - "prompt_tokens": 68, - "total_tokens": 118 - } - }, - "output_filtering": { - "message": "Output filter passed successfully.", - "data": { - "original_service_response": { - "Hate": 0, - "SelfHarm": 0, - "Sexual": 0, - "Violence": 0 - }, - "checked_text": "Cozy Downtown Apartment for Sublet!\n\nLooking for a temporary place to call home in the heart of downtown? Look no further! This cozy apartment is up for subletting and offers a convenient and vibrant city living experience.\n\nFeatures:\n- Prime location" - } - }, - "output_unmasking": null - }, - "orchestration_result": { - "object": "chat.completion", - "id": "chatcmpl-9o4df7DpIjY6CJdfe9hws1lrWZbHq", - "created": 1721721259, - "model": "gpt-35-turbo", - "system_fingerprint": null, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Cozy Downtown Apartment for Sublet!\n\nLooking for a temporary place to call home in the heart of downtown? Look no further! 
This cozy apartment is up for subletting and offers a convenient and vibrant city living experience.\n\nFeatures:\n- Prime location" - }, - "logprobs": { - }, - "finish_reason": "length" - } - ], - "usage": { - "completion_tokens": 50, - "prompt_tokens": 68, - "total_tokens": 118 - } - } -} diff --git a/orchestration/src/test/resources/serializedRequest.json b/orchestration/src/test/resources/serializedRequest.json new file mode 100644 index 00000000..a4101d65 --- /dev/null +++ b/orchestration/src/test/resources/serializedRequest.json @@ -0,0 +1,60 @@ +{ + "orchestration_config" : { + "module_configurations" : { + "llm_module_config" : { + "model_name" : "gpt-35-turbo-16k", + "model_params" : { + "frequency_penalty" : 1, + "temperature" : 0.5 + }, + "model_version" : "latest" + }, + "templating_module_config" : { + "template" : [ { + "role" : "user", + "content" : "{{?input}}" + } ], + "defaults" : { + "input" : "Hello World!" + } + }, + "filtering_module_config" : { + "input" : { + "filters" : [ { + "type" : "azure_content_safety", + "config" : { + "SelfHarm" : 6 + } + } ] + }, + "output" : { + "filters" : [ { + "type" : "azure_content_safety", + "config" : { + "Hate" : 0, + "SelfHarm" : 2, + "Sexual": 4, + "Violence": 6 + } + } ] + } + }, + "masking_module_config" : { + "masking_providers" : [ { + "type" : "sap_data_privacy_integration", + "method" : "anonymization", + "entities" : [ { + "type" : "profile-address" + }, { + "type" : "profile-iban" + }, { + "type" : "profile-location" + } ] + } ] + } + } + }, + "input_params" : { + "input" : "Reply with 'Orchestration Service is working!' 
in German" + } +} \ No newline at end of file diff --git a/orchestration/src/test/resources/templatingResponse.json b/orchestration/src/test/resources/serializedResponse.json similarity index 59% rename from orchestration/src/test/resources/templatingResponse.json rename to orchestration/src/test/resources/serializedResponse.json index 3fcc0a30..d471d2f5 100644 --- a/orchestration/src/test/resources/templatingResponse.json +++ b/orchestration/src/test/resources/serializedResponse.json @@ -4,7 +4,7 @@ "templating": [ { "role": "user", - "content": "Reply with 'Orchestration Service is working!' in German" + "content": "Hello there!" } ], "llm": { @@ -17,7 +17,7 @@ "index": 0, "message": { "role": "assistant", - "content": "Orchestration Service funktioniert!" + "content": "General Kenobi!" }, "finish_reason": "stop" } @@ -27,6 +27,30 @@ "prompt_tokens": 19, "total_tokens": 26 } + }, + "input_filtering": { + "message": "Input filter passed successfully.", + "data": { + "original_service_response": { + "Hate": 0, + "SelfHarm": 0, + "Sexual": 0, + "Violence": 2 + }, + "checked_text": "Hello there!" + } + }, + "output_filtering": { + "message": "Output filter passed successfully.", + "data": { + "original_service_response": { + "Hate": 0, + "SelfHarm": 0, + "Sexual": 0, + "Violence": 2 + }, + "checked_text": "General Kenobi!" + } } }, "orchestration_result": { @@ -39,7 +63,7 @@ "index": 0, "message": { "role": "assistant", - "content": "Orchestration Service funktioniert!" + "content": "General Kenobi!" 
}, "finish_reason": "stop" } diff --git a/orchestration/src/test/resources/templatingRequest.json b/orchestration/src/test/resources/templatingRequest.json deleted file mode 100644 index 39f9a3ba..00000000 --- a/orchestration/src/test/resources/templatingRequest.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "orchestration_config" : { - "module_configurations" : { - "llm_module_config" : { - "model_name" : "gpt-35-turbo-16k", - "model_params" : { - "frequency_penalty" : 0, - "presence_penalty" : 0, - "temperature" : 0.1, - "max_tokens" : 50 - }, - "model_version" : "latest" - }, - "templating_module_config" : { - "template" : [ { - "role" : "user", - "content" : "{{?input}}" - } ] - } - } - }, - "input_params" : { - "input" : "Reply with 'Orchestration Service is working!' in German" - }, - "messages_history" : [ ] -} \ No newline at end of file diff --git a/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index f1d86b6a..7c5548fc 100644 --- a/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -28,7 +28,8 @@ class OrchestrationController { // uses defaults from application.yaml @Autowired OrchestrationChatModel client; - OrchestrationChatModel clientWithoutDefaults = new OrchestrationChatModel(new OrchestrationClient()); + OrchestrationChatModel clientWithoutDefaults = + new OrchestrationChatModel(new OrchestrationClient().withMaskingConfig()); @GetMapping("/completion") ChatResponse completion() {