diff --git a/core/pom.xml b/core/pom.xml index 1d0f2ce3..fb55f23c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -159,6 +159,8 @@ false + + true diff --git a/orchestration/pom.xml b/orchestration/pom.xml index 196724c2..ab34609b 100644 --- a/orchestration/pom.xml +++ b/orchestration/pom.xml @@ -33,32 +33,88 @@ - com.sap.cloud.sdk.datamodel - openapi-core + com.sap.ai.sdk + core + - org.springframework - spring-core + com.sap.cloud.sdk.cloudplatform + cloudplatform-connectivity - org.springframework - spring-web + com.sap.cloud.sdk.cloudplatform + connectivity-apache-httpclient5 + - com.sap.cloud.sdk.cloudplatform - cloudplatform-connectivity + org.apache.httpcomponents.core5 + httpcore5 + + + org.apache.httpcomponents.client5 + httpclient5 com.google.code.findbugs jsr305 - com.google.guava - guava + com.fasterxml.jackson.core + jackson-annotations com.fasterxml.jackson.core - jackson-annotations + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + io.vavr + vavr + + + org.slf4j + slf4j-api + + + + + org.springframework + spring-web + + + + + org.springframework.ai + spring-ai-core + true + + + org.springframework + spring-context + true + + + org.springframework.boot + spring-boot + true + + + org.springframework.boot + spring-boot-autoconfigure + true + + + + org.projectlombok + lombok + provided @@ -71,19 +127,14 @@ wiremock test - - org.apache.httpcomponents.core5 - httpcore5 - test - org.assertj assertj-core test - com.sap.ai.sdk - core + org.mockito + mockito-core test @@ -114,6 +165,7 @@ openapi-generator-maven-plugin ${cloud-sdk.version} + true ${project.basedir}/src/main/java released true diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java new file mode 100644 index 00000000..9cc7e5c3 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/AzureContentFilter.java @@ -0,0 +1,48 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.fromValue; + +import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.Data; +import lombok.RequiredArgsConstructor; +import lombok.experimental.Accessors; + +@Data +@Accessors(fluent = true) +public class AzureContentFilter implements ContentFilter { + @Nullable private Sensitivity hate; + @Nullable private Sensitivity selfHarm; + @Nullable private Sensitivity sexual; + @Nullable private Sensitivity violence; + + @RequiredArgsConstructor + public enum Sensitivity { + HIGH(0), + MEDIUM(2), + LOW(4); + + private final int value; + } + + @Nonnull + com.sap.ai.sdk.orchestration.client.model.FilterConfig toFilterConfigDTO() { + var dto = AzureContentSafety.create(); + if (hate != null) { + dto.hate(fromValue(hate.value)); + } + if (selfHarm != null) { + dto.selfHarm(fromValue(selfHarm.value)); + } + if (sexual != null) { + dto.sexual(fromValue(sexual.value)); + } + if (violence != null) { + dto.violence(fromValue(violence.value)); + } + return com.sap.ai.sdk.orchestration.client.model.FilterConfig.create() + .type(com.sap.ai.sdk.orchestration.client.model.FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY) + .config(dto); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ContentFilter.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ContentFilter.java new file mode 100644 index 00000000..c0703376 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ContentFilter.java @@ -0,0 +1,3 @@ +package com.sap.ai.sdk.orchestration; + +public interface ContentFilter {} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfig.java new file mode 100644 index 00000000..6b9cda11 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfig.java @@ -0,0 +1,127 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import io.vavr.control.Option; +import javax.annotation.Nonnull; +import lombok.AccessLevel; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; + +@Data +@Setter(AccessLevel.PRIVATE) +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public class DefaultOrchestrationConfig> + implements OrchestrationConfig { + + @Nonnull private Option llmConfig = Option.none(); + @Nonnull private Option template = Option.none(); + @Nonnull private Option maskingConfig = Option.none(); + @Nonnull private Option inputContentFilter = Option.none(); + @Nonnull private Option outputContentFilter = Option.none(); + + @EqualsAndHashCode.Exclude + @ToString.Exclude + @Getter(AccessLevel.NONE) + @Nonnull + private final T wrapper; + + @SuppressWarnings("unchecked") + private DefaultOrchestrationConfig() { + wrapper = (T) this; + } + + /** + * Create a new instance of {@link DefaultOrchestrationConfig} to delegate to. This is useful when + * exposing the {@link OrchestrationConfig} in other objects, without re-implementing it. To + * maintain fluent API usage, the given wrapper object will be returned by the fluent methods, + * instead of this instance. + * + * @param wrapper The wrapper that delegates to this object. + * @param The type of the wrapper object. + * @return The new instance. + * @see #standalone() + */ + @Nonnull + public static > DefaultOrchestrationConfig asDelegateFor( + @Nonnull final T wrapper) { + return new DefaultOrchestrationConfig<>(wrapper); + } + + /** + * Create an implementation without any object delegating to it. The fluent API will return this + * object itself. + * + * @return The new instance. + * @see #asDelegateFor(OrchestrationConfig) + */ + @Nonnull + public static DefaultOrchestrationConfig standalone() { + return new DefaultOrchestrationConfig<>(); + } + + @Nonnull + @Override + public T withLlmConfig(@Nonnull final LLMModuleConfig llm) { + this.llmConfig = Option.some(llm); + return wrapper; + } + + @Nonnull + @Override + public T withTemplate(@Nonnull final TemplatingModuleConfig template) { + this.template = Option.some(template); + return wrapper; + } + + @Nonnull + @Override + public T withMaskingConfig(@Nonnull final MaskingConfig maskingConfig) { + this.maskingConfig = Option.some(maskingConfig); + return wrapper; + } + + @Nonnull + @Override + public T withInputContentFilter(@Nonnull final ContentFilter filter) { + this.inputContentFilter = Option.some(filter); + return wrapper; + } + + @Nonnull + @Override + public T withOutputContentFilter(@Nonnull final ContentFilter filter) { + this.outputContentFilter = Option.some(filter); + return wrapper; + } + + /** + * Copy the configuration into the given target configuration. The copy is + * shallow and does not override any existing configuration. + * + *

This has two main use cases: + * + *

    + *
  1. Duplicating a config + *
  2. Applying defaults to a config + *
+ * + * @param source The source configuration to copy from. + * @return This (delegate) object. + */ + @Nonnull + public DefaultOrchestrationConfig copyFrom(@Nonnull final OrchestrationConfig source) { + llmConfig.orElse(source::getLlmConfig).forEach(this::withLlmConfig); + template.orElse(source::getTemplate).forEach(this::withTemplate); + maskingConfig.orElse(source::getMaskingConfig).forEach(this::withMaskingConfig); + inputContentFilter.orElse(source::getInputContentFilter).forEach(this::withInputContentFilter); + outputContentFilter + .orElse(source::getOutputContentFilter) + .forEach(this::withOutputContentFilter); + return this; + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DpiMaskingConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DpiMaskingConfig.java new file mode 100644 index 00000000..7da4e8f2 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/DpiMaskingConfig.java @@ -0,0 +1,61 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig.MethodEnum.ANONYMIZATION; +import static com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig.MethodEnum.PSEUDONYMIZATION; +import static com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig.TypeEnum.SAP_DATA_PRIVACY_INTEGRATION; + +import com.sap.ai.sdk.orchestration.client.model.DPIEntities; +import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig; +import com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import io.vavr.control.Option; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import lombok.Value; + +@Value +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public class DpiMaskingConfig implements MaskingConfig { + @Nonnull MaskingProviderConfig.MethodEnum maskingMethod; + @Nonnull List entities; + + @Nonnull + public static Builder anonymization() { + return new DpiMaskingConfig.Builder(ANONYMIZATION); + } + + @Nonnull + public static Builder pseudonymization() { + return new DpiMaskingConfig.Builder(PSEUDONYMIZATION); + } + + @Nonnull + MaskingProviderConfig toMaskingProviderDTO() { + var entities = this.entities.stream().map(it -> DPIEntityConfig.create().type(it)).toList(); + return MaskingProviderConfig.create() + .type(SAP_DATA_PRIVACY_INTEGRATION) + .method(maskingMethod) + .entities(entities); + } + + @RequiredArgsConstructor(access = AccessLevel.PRIVATE) + public static class Builder { + private final MaskingProviderConfig.MethodEnum maskingMethod; + + @Nonnull + public DpiMaskingConfig withEntities(@Nonnull final DPIEntities entity, @Nullable final DPIEntities... entities) { + var entitiesList = new ArrayList(); + entitiesList.add(entity); + if ( entities != null) { + entitiesList.addAll(Arrays.asList(entities)); + } + return new DpiMaskingConfig(maskingMethod, entitiesList); + } + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/MaskingConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/MaskingConfig.java new file mode 100644 index 00000000..21f48ab0 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/MaskingConfig.java @@ -0,0 +1,3 @@ +package com.sap.ai.sdk.orchestration; + +public interface MaskingConfig {} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ModuleConfigFactory.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ModuleConfigFactory.java new file mode 100644 index 00000000..561f9461 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/ModuleConfigFactory.java @@ -0,0 +1,75 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.FilteringConfig; +import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs; +import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import io.vavr.control.Option; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nonnull; + +class ModuleConfigFactory { + @Nonnull + static ModuleConfigs toModuleConfigDTO( + @Nonnull final OrchestrationConfig config, @Nonnull final List messages) { + LLMModuleConfig llm = + config + .getLlmConfig() + .getOrElseThrow(() -> new IllegalStateException("LLM module config is required")); + + /* + * Currently, we have to merge the prompt into the template configuration. + * This works around the limitation that the template config isn't optional. + * This comes at the risk that the prompt unintentionally contains the templating pattern "{{? .. }}". + * In this case, the request will fail, since the templating module will try to resolve the parameter. + * To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662 + */ + Option maybeTemplate = config.getTemplate(); + final var messagesWithPrompt = new ArrayList<>(messages); + maybeTemplate.map(TemplatingModuleConfig::getTemplate).forEach(messagesWithPrompt::addAll); + + if (messagesWithPrompt.isEmpty()) { + throw new IllegalStateException( + "A prompt is required. Pass at least one message or configure the templating module."); + } + TemplatingModuleConfig template = TemplatingModuleConfig.create().template(messagesWithPrompt); + maybeTemplate.map(TemplatingModuleConfig::getDefaults).forEach(template::defaults); + + ModuleConfigs dto = + ModuleConfigs.create().llmModuleConfig(llm).templatingModuleConfig(template); + + config + .getMaskingConfig() + .filter(DpiMaskingConfig.class::isInstance) + .map(DpiMaskingConfig.class::cast) + .map(DpiMaskingConfig::toMaskingProviderDTO) + .map(it -> MaskingModuleConfig.create().maskingProviders(it)) + .forEach(dto::maskingModuleConfig); + + var maybeInputFilter = config.getInputContentFilter(); + var maybeOutputFilter = config.getOutputContentFilter(); + + if (maybeInputFilter.isDefined() || maybeOutputFilter.isDefined()) { + var filter = FilteringModuleConfig.create(); + maybeInputFilter + .filter(AzureContentFilter.class::isInstance) + .map(AzureContentFilter.class::cast) + .map(AzureContentFilter::toFilterConfigDTO) + .map(it -> FilteringConfig.create().filters(it)) + .forEach(filter::input); + maybeOutputFilter + .filter(AzureContentFilter.class::isInstance) + .map(AzureContentFilter.class::cast) + .map(AzureContentFilter::toFilterConfigDTO) + .map(it -> FilteringConfig.create().filters(it)) + .forEach(filter::output); + dto = dto.filteringModuleConfig(filter); + } + + return dto; + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java new file mode 100644 index 00000000..0df9489f --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -0,0 +1,120 @@ +package com.sap.ai.sdk.orchestration; + +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.PropertyAccessor; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import com.sap.ai.sdk.core.Core; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; +import java.io.IOException; +import java.util.stream.Stream; +import javax.annotation.Nonnull; +import lombok.AllArgsConstructor; +import lombok.experimental.Delegate; +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicClassicHttpRequest; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; + +@AllArgsConstructor +public class OrchestrationClient implements OrchestrationConfig { + static final ObjectMapper JACKSON; + + static { + JACKSON = + new Jackson2ObjectMapperBuilder() + .modules(new JavaTimeModule()) + .visibility(PropertyAccessor.GETTER, JsonAutoDetect.Visibility.NONE) + .visibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.NONE) + .serializationInclusion(JsonInclude.Include.NON_NULL) + .build(); + } + + private interface IDelegate extends OrchestrationConfig {} + + @Delegate(types = IDelegate.class) + @Nonnull + private final DefaultOrchestrationConfig clientConfig = + DefaultOrchestrationConfig.asDelegateFor(this); + + @Nonnull private final HttpDestination destination; + + public OrchestrationClient() { + // TODO: use AiCoreService after refactoring + this.destination = Core.getDestinationForDeployment("db1d64d9f06be467", "default").asHttp(); + } + + /** + * Generate a completion for the given user prompt. + * + * @param userPrompt a text message. + * @return the completion output + * @throws OrchestrationClientException if the request fails + */ + @Nonnull + public String chatCompletion(@Nonnull final String userPrompt) + throws OrchestrationClientException { + var response = chatCompletion(new OrchestrationPrompt(userPrompt)); + + var choice = response.getOrchestrationResult().getChoices().get(0); + if (choice.getFinishReason().equalsIgnoreCase("content_filter")) { + throw new OrchestrationClientException("Output content filter triggered"); + } + return choice.getMessage().getContent(); + } + + /** + * Generate a completion for the given prompt. + * + * @param prompt the prompt, including messages and other parameters. + * @return the completion output + * @throws OrchestrationClientException if the request fails + */ + @Nonnull + public CompletionPostResponse chatCompletion(@Nonnull final OrchestrationPrompt prompt) + throws OrchestrationClientException { + var dto = prompt.toCompletionPostRequestDTO(clientConfig); + + return executeRequest(dto); + } + + @Nonnull + public Stream streamChatCompletion(@Nonnull final String prompt) + throws OrchestrationClientException { + throw new RuntimeException("Not implemented"); + } + + @Nonnull + public Stream streamChatCompletionDelta(@Nonnull final OrchestrationPrompt prompt) + throws OrchestrationClientException { + throw new RuntimeException("Not implemented"); + } + + @SuppressWarnings("UnstableApiUsage") + @Nonnull + protected CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request) { + final var client = ApacheHttpClient5Accessor.getHttpClient(destination); + // TODO: update after AiCoreService refactoring + final BasicClassicHttpRequest postRequest = + new HttpPost("/v2/inference/deployments/db1d64d9f06be467/completion"); + try { + final var json = JACKSON.writeValueAsString(request); + postRequest.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON)); + } catch (final JsonProcessingException e) { + throw new OrchestrationClientException("Failed to serialize request parameters", e); + } + + try { + return client.execute( + postRequest, new OrchestrationResponseHandler<>(CompletionPostResponse.class)); + } catch (IOException e) { + throw new OrchestrationClientException("Failed to execute request", e); + } + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java new file mode 100644 index 00000000..749f9474 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClientException.java @@ -0,0 +1,11 @@ +package com.sap.ai.sdk.orchestration; + +public class OrchestrationClientException extends RuntimeException { + public OrchestrationClientException(String msg, Throwable cause) { + super(msg, cause); + } + + public OrchestrationClientException(String msg) { + super(msg); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationConfig.java new file mode 100644 index 00000000..4064b2c2 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationConfig.java @@ -0,0 +1,45 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import io.vavr.control.Option; +import javax.annotation.Nonnull; + +/** + * Represents the configuration for the orchestration service. Allows for configuring the different + * modules of the orchestration service via a fluent API. + * + * @param Type of the specific implementation to make a fluent API possible. + */ +public interface OrchestrationConfig> { + + @Nonnull + Option getLlmConfig(); + + @Nonnull + Option getTemplate(); + + @Nonnull + Option getMaskingConfig(); + + @Nonnull + Option getInputContentFilter(); + + @Nonnull + Option getOutputContentFilter(); + + @Nonnull + T withLlmConfig(@Nonnull final LLMModuleConfig llm); + + @Nonnull + T withTemplate(@Nonnull final TemplatingModuleConfig template); + + @Nonnull + T withMaskingConfig(@Nonnull final MaskingConfig maskingConfig); + + @Nonnull + T withInputContentFilter(@Nonnull final ContentFilter filter); + + @Nonnull + T withOutputContentFilter(@Nonnull final ContentFilter filter); +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java new file mode 100644 index 00000000..aff6b082 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationPrompt.java @@ -0,0 +1,52 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; +import java.util.List; +import java.util.Map; +import javax.annotation.Nonnull; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Value; +import lombok.experimental.Delegate; + +@Value +@AllArgsConstructor +public class OrchestrationPrompt implements OrchestrationConfig { + @Nonnull List messages; + @Nonnull Map templateParameters; + + private interface IDelegate extends OrchestrationConfig {} + + @Getter(AccessLevel.NONE) + @Delegate(types = IDelegate.class) + @Nonnull + DefaultOrchestrationConfig delegate = + DefaultOrchestrationConfig.asDelegateFor(this); + + public OrchestrationPrompt(@Nonnull final String message) { + this(List.of(ChatMessage.create().role("user").content(message)), Map.of()); + } + + public OrchestrationPrompt(@Nonnull final List messagesHistory) { + this(messagesHistory, Map.of()); + } + + public OrchestrationPrompt(@Nonnull final Map inputParams) { + this(List.of(), inputParams); + } + + @Nonnull + CompletionPostRequest toCompletionPostRequestDTO(@Nonnull final OrchestrationConfig defaults) { + // duplicate the prompt config, then apply the defaults to the copy + // that way this prompt remains unchanged and can be reused + var config = DefaultOrchestrationConfig.standalone().copyFrom(this).copyFrom(defaults); + var moduleConfigDTO = ModuleConfigFactory.toModuleConfigDTO(config, messages); + return CompletionPostRequest.create() + .orchestrationConfig( + com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig.create() + .moduleConfigurations(moduleConfigDTO)) + .inputParams(templateParameters); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandler.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandler.java new file mode 100644 index 00000000..9d062406 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandler.java @@ -0,0 +1,123 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.OrchestrationClient.JACKSON; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.sap.ai.sdk.orchestration.client.model.ErrorResponse; +import io.vavr.control.Try; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.hc.core5.http.ClassicHttpResponse; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.HttpClientResponseHandler; +import org.apache.hc.core5.http.io.entity.EntityUtils; + +@Slf4j +@RequiredArgsConstructor +class OrchestrationResponseHandler implements HttpClientResponseHandler { + // TODO reduce code duplication with OpenAI + @Nonnull private final Class responseType; + + /** + * Processes a {@link ClassicHttpResponse} and returns some value corresponding to that response. + * + * @param response The response to process + * @return A model class instantiated from the response + * @throws OrchestrationClientException in case of a problem or the connection was aborted + */ + @Override + public T handleResponse(@Nonnull final ClassicHttpResponse response) + throws OrchestrationClientException { + if (response.getCode() >= 300) { + buildExceptionAndThrow(response); + } + return parseResponse(response); + } + + // The InputStream of the HTTP entity is closed by EntityUtils.toString + @SuppressWarnings("PMD.CloseResource") + @Nonnull + private T parseResponse(@Nonnull final ClassicHttpResponse response) + throws OrchestrationClientException { + final HttpEntity responseEntity = response.getEntity(); + if (responseEntity == null) { + throw new OrchestrationClientException("Response from Orchestration service was empty."); + } + final var content = getContent(responseEntity); + try { + return JACKSON.readValue(content, responseType); + } catch (final JsonProcessingException e) { + log.error("Failed to parse the following response from orchestration service: {}", content); + throw new OrchestrationClientException( + "Failed to parse response from orchestration service", e); + } + } + + @Nonnull + private static String getContent(@Nonnull final HttpEntity entity) { + try { + return EntityUtils.toString(entity, StandardCharsets.UTF_8); + } catch (IOException | ParseException e) { + throw new OrchestrationClientException("Failed to read response content.", e); + } + } + + // The InputStream of the HTTP entity is closed by EntityUtils.toString + @SuppressWarnings("PMD.CloseResource") + static void buildExceptionAndThrow(@Nonnull final ClassicHttpResponse response) + throws OrchestrationClientException { + final var exception = + new OrchestrationClientException( + "Request to orchestration service failed with status %s %s" + .formatted(response.getCode(), response.getReasonPhrase())); + final var entity = response.getEntity(); + if (entity == null) { + throw exception; + } + final var maybeContent = Try.of(() -> getContent(entity)); + if (maybeContent.isFailure()) { + exception.addSuppressed(maybeContent.getCause()); + throw exception; + } + final var content = maybeContent.get(); + if (content.isBlank()) { + throw exception; + } + + log.error( + "The orchestration service responded with an HTTP error and the following content: {}", + content); + final var contentType = ContentType.parse(entity.getContentType()); + if (!ContentType.APPLICATION_JSON.isSameMimeType(contentType)) { + throw exception; + } + + parseErrorAndThrow(content, exception); + } + + /** + * Parse the error response and throw an exception. + * + * @param errorResponse the error response, most likely a JSON of {@link ErrorResponse}. + * @param baseException a base exception to add the error message to. + */ + static void parseErrorAndThrow( + @Nonnull final String errorResponse, + @Nonnull final OrchestrationClientException baseException) + throws OrchestrationClientException { + final var maybeError = Try.of(() -> JACKSON.readValue(errorResponse, ErrorResponse.class)); + if (maybeError.isFailure()) { + baseException.addSuppressed(maybeError.getCause()); + throw baseException; + } + + throw new OrchestrationClientException( + "%s and error message: '%s'" + .formatted(baseException.getMessage(), maybeError.get().getMessage())); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/client/OrchestrationCompletionApi.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/client/OrchestrationCompletionApi.java deleted file mode 100644 index ab4786d7..00000000 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/client/OrchestrationCompletionApi.java +++ /dev/null @@ -1,97 +0,0 @@ -package com.sap.ai.sdk.orchestration.client; - -import com.google.common.annotations.Beta; -import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; -import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; -import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; -import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; -import com.sap.cloud.sdk.services.openapi.core.AbstractOpenApiService; -import com.sap.cloud.sdk.services.openapi.core.OpenApiRequestException; -import java.util.List; -import javax.annotation.Nonnull; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.web.util.UriComponentsBuilder; - -/** - * Internal Orchestration Service API in version 0.0.1. - * - *

SAP AI Core - Orchestration Service API - */ -public class OrchestrationCompletionApi extends AbstractOpenApiService { - /** - * Instantiates this API class to invoke operations on the Internal Orchestration Service API. - * - * @param httpDestination The destination that API should be used with - */ - public OrchestrationCompletionApi(@Nonnull final Destination httpDestination) { - super(httpDestination); - } - - /** - * Instantiates this API class to invoke operations on the Internal Orchestration Service API - * based on a given {@link ApiClient}. - * - * @param apiClient ApiClient to invoke the API on - */ - @Beta - public OrchestrationCompletionApi(@Nonnull final ApiClient apiClient) { - super(apiClient); - } - - /** - * 200 - Successful response - * - *

400 - Bad Request - * - *

0 - Common Error - * - * @param completionPostRequest The value for the parameter completionPostRequest - * @return CompletionPostResponse - * @throws OpenApiRequestException if an error occurs while attempting to invoke the API - */ - @Nonnull - public CompletionPostResponse orchestrationV1EndpointsCreate( - @Nonnull final CompletionPostRequest completionPostRequest) throws OpenApiRequestException { - final Object localVarPostBody = completionPostRequest; - - // verify the required parameter 'completionPostRequest' is set - if (completionPostRequest == null) { - throw new OpenApiRequestException( - "Missing the required parameter 'completionPostRequest' when calling orchestrationV1EndpointsCreate"); - } - - final String localVarPath = UriComponentsBuilder.fromPath("/completion").build().toUriString(); - - final MultiValueMap localVarQueryParams = - new LinkedMultiValueMap(); - final HttpHeaders localVarHeaderParams = new HttpHeaders(); - final MultiValueMap localVarFormParams = - new LinkedMultiValueMap(); - - final String[] localVarAccepts = {"application/json"}; - final List localVarAccept = apiClient.selectHeaderAccept(localVarAccepts); - final String[] localVarContentTypes = {"application/json"}; - final MediaType localVarContentType = apiClient.selectHeaderContentType(localVarContentTypes); - - final String[] localVarAuthNames = new String[] {}; - - final ParameterizedTypeReference localVarReturnType = - new ParameterizedTypeReference() {}; - return apiClient.invokeAPI( - localVarPath, - HttpMethod.POST, - localVarQueryParams, - localVarPostBody, - localVarHeaderParams, - localVarFormParams, - localVarAccept, - localVarContentType, - localVarAuthNames, - localVarReturnType); - } -} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/client/OrchestrationHealthzApi.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/client/OrchestrationHealthzApi.java deleted file mode 100644 index 6b6054c5..00000000 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/client/OrchestrationHealthzApi.java +++ /dev/null @@ -1,85 +0,0 @@ -package com.sap.ai.sdk.orchestration.client; - -import com.google.common.annotations.Beta; -import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; -import com.sap.cloud.sdk.services.openapi.apiclient.ApiClient; -import com.sap.cloud.sdk.services.openapi.core.AbstractOpenApiService; -import com.sap.cloud.sdk.services.openapi.core.OpenApiRequestException; -import java.util.List; -import javax.annotation.Nonnull; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; -import org.springframework.web.util.UriComponentsBuilder; - -/** - * Internal Orchestration Service API in version 0.0.1. - * - *

SAP AI Core - Orchestration Service API - */ -public class OrchestrationHealthzApi extends AbstractOpenApiService { - /** - * Instantiates this API class to invoke operations on the Internal Orchestration Service API. - * - * @param httpDestination The destination that API should be used with - */ - public OrchestrationHealthzApi(@Nonnull final Destination httpDestination) { - super(httpDestination); - } - - /** - * Instantiates this API class to invoke operations on the Internal Orchestration Service API - * based on a given {@link ApiClient}. - * - * @param apiClient ApiClient to invoke the API on - */ - @Beta - public OrchestrationHealthzApi(@Nonnull final ApiClient apiClient) { - super(apiClient); - } - - /** - * 200 - Service is up and running. - * - *

503 - Service is unavailable. - * - * @return String - * @throws OpenApiRequestException if an error occurs while attempting to invoke the API - */ - @Nonnull - public String orchestrationV1EndpointsHealthz() throws OpenApiRequestException { - final Object localVarPostBody = null; - - final String localVarPath = UriComponentsBuilder.fromPath("/healthz").build().toUriString(); - - final MultiValueMap localVarQueryParams = - new LinkedMultiValueMap(); - final HttpHeaders localVarHeaderParams = new HttpHeaders(); - final MultiValueMap localVarFormParams = - new LinkedMultiValueMap(); - - final String[] localVarAccepts = {"text/plain", "application/json"}; - final List localVarAccept = apiClient.selectHeaderAccept(localVarAccepts); - final String[] localVarContentTypes = {}; - final MediaType localVarContentType = apiClient.selectHeaderContentType(localVarContentTypes); - - final String[] localVarAuthNames = new String[] {}; - - final ParameterizedTypeReference localVarReturnType = - new ParameterizedTypeReference() {}; - return apiClient.invokeAPI( - localVarPath, - HttpMethod.GET, - localVarQueryParams, - localVarPostBody, - localVarHeaderParams, - localVarFormParams, - localVarAccept, - localVarContentType, - localVarAuthNames, - localVarReturnType); - } -} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationAutoConfiguration.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationAutoConfiguration.java new file mode 100644 index 00000000..7f5e27df --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationAutoConfiguration.java @@ -0,0 +1,38 @@ +package com.sap.ai.sdk.orchestration.spring; + +import com.sap.ai.sdk.orchestration.OrchestrationClient; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +@Slf4j +@AutoConfiguration +@EnableConfigurationProperties(OrchestrationSpringProperties.class) +public class OrchestrationAutoConfiguration { + @Bean + @ConditionalOnMissingBean + OrchestrationClient orchestrationClient(OrchestrationSpringProperties properties) { + var llmConfig = + LLMModuleConfig.create().modelName(properties.llm().modelName()).modelParams(Map.of()); + if (properties.llm().modelVersion() != null) { + llmConfig.modelVersion(properties.llm().modelVersion()); + } + return new OrchestrationClient().withLlmConfig(llmConfig); + } + + @Bean + @ConditionalOnMissingBean + OrchestrationChatModel orchestrationChatModel(OrchestrationClient client) { + return new OrchestrationChatModel(client); + } + // + // @Bean + // @ConditionalOnMissingBean + // ChatClient orchestrationChatClient(OrchestrationChatModel model) { + // return ChatClient.create(model); + // } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java new file mode 100644 index 00000000..32306e3a --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatModel.java @@ -0,0 +1,63 @@ +package com.sap.ai.sdk.orchestration.spring; + +import com.sap.ai.sdk.orchestration.OrchestrationClient; +import com.sap.ai.sdk.orchestration.OrchestrationPrompt; +import javax.annotation.Nonnull; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; + +/** Spring AI integration for the orchestration service. */ +@Slf4j +public class OrchestrationChatModel implements ChatModel { + @Nonnull private final OrchestrationClient client; + + public OrchestrationChatModel(@Nonnull final OrchestrationClient client) { + this.client = client; + } + + @Override + @Nonnull + public ChatOptions getDefaultOptions() { + return new OrchestrationChatOptions(); + } + + @Override + public OrchestrationChatResponse call(@Nonnull final Prompt prompt) { + var orchestrationPrompt = toOrchestrationPrompt(prompt); + var response = client.chatCompletion(orchestrationPrompt); + return OrchestrationChatResponse.fromOrchestrationResponse(response); + } + + @Nonnull + private static OrchestrationPrompt toOrchestrationPrompt(@Nonnull final Prompt prompt) { + var messages = OrchestrationChatOptions.toChatMessages(prompt.getInstructions()); + + var opts = getChatOptions(prompt); + var orchestrationPrompt = new OrchestrationPrompt(messages, opts.getTemplateParameters()); + + opts.getLlmConfig().forEach(orchestrationPrompt::withLlmConfig); + opts.getTemplate().forEach(orchestrationPrompt::withTemplate); + opts.getMaskingConfig().forEach(orchestrationPrompt::withMaskingConfig); + opts.getInputContentFilter().forEach(orchestrationPrompt::withInputContentFilter); + opts.getOutputContentFilter().forEach(orchestrationPrompt::withOutputContentFilter); + + return orchestrationPrompt; + } + + @Nonnull + private static OrchestrationChatOptions getChatOptions(@Nonnull final Prompt prompt) { + if (prompt.getOptions() == null) { + return new OrchestrationChatOptions(); + } + if (prompt.getOptions() instanceof OrchestrationChatOptions opts) { + return opts; + } + // TODO: Should we build the LLM config out of the provided options instead? + log.warn( + "Prompt options are not of type {}. Ignoring provided options.", + OrchestrationChatOptions.class.getSimpleName()); + return new OrchestrationChatOptions(); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java new file mode 100644 index 00000000..23c2bf51 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptions.java @@ -0,0 +1,131 @@ +package com.sap.ai.sdk.orchestration.spring; + +import com.sap.ai.sdk.orchestration.DefaultOrchestrationConfig; +import com.sap.ai.sdk.orchestration.OrchestrationConfig; +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import java.util.List; +import java.util.Map; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import lombok.AccessLevel; +import lombok.Data; +import lombok.Getter; +import lombok.Setter; +import lombok.experimental.Delegate; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.prompt.ChatOptions; + +/** Configuration to be used for orchestration requests. */ +@Data +@Getter(AccessLevel.NONE) +@Setter(AccessLevel.NONE) +public class OrchestrationChatOptions + implements ChatOptions, OrchestrationConfig { + private interface IDelegate extends OrchestrationConfig {} + + @Getter(AccessLevel.NONE) + @Nonnull + @Delegate(types = IDelegate.class) + private final DefaultOrchestrationConfig delegate = + DefaultOrchestrationConfig.asDelegateFor(this); + + @Getter(AccessLevel.PUBLIC) + @Nonnull + private Map templateParameters = Map.of(); + + @Nonnull + public OrchestrationChatOptions withTemplateParameters( + @Nonnull final Map templateParameters) { + this.templateParameters = templateParameters; + return this; + } + + @Nonnull + public OrchestrationChatOptions withTemplate(@Nonnull final List template) { + delegate.withTemplate(TemplatingModuleConfig.create().template(toChatMessages(template))); + return this; + } + + @Nonnull + static List toChatMessages(@Nonnull final List messages) { + return messages.stream() + .map(m -> ChatMessage.create().role(m.getMessageType().getValue()).content(m.getContent())) + .toList(); + } + + // region satisfy the ChatOptions interface, delegating to the LLM config + @Nullable + @Override + public String getModel() { + return delegate.getLlmConfig().map(LLMModuleConfig::getModelName).getOrNull(); + } + + @Nullable + String getModelVersion() { + return delegate.getLlmConfig().map(LLMModuleConfig::getModelVersion).getOrNull(); + } + + @Nullable + @Override + public Double getFrequencyPenalty() { + return getLlmConfigParam("frequencyPenalty"); + } + + @Nullable + @Override + public Integer getMaxTokens() { + return getLlmConfigParam("maxTokens"); + } + + @Nullable + @Override + public Double getPresencePenalty() { + return getLlmConfigParam("presencePenalty"); + } + + @Nullable + @Override + public List getStopSequences() { + return getLlmConfigParam("stopSequences"); + } + + @Nullable + @Override + public Double getTemperature() { + return getLlmConfigParam("temperature"); + } + + @Nullable + @Override + public Integer getTopK() { + return getLlmConfigParam("topK"); + } + + @Nullable + @Override + public Double getTopP() { + return getLlmConfigParam("topP"); + } + + @Override + public OrchestrationChatOptions copy() { + var copy = new OrchestrationChatOptions(); + copy.delegate.copyFrom(this.delegate); + copy.templateParameters.putAll(this.templateParameters); + return copy; + } + + @SuppressWarnings("unchecked") + @Nullable + private T getLlmConfigParam(@Nonnull final String param) { + return delegate + .getLlmConfig() + .map(LLMModuleConfig::getModelParams) + .map(it -> (Map) it) + .map(m -> (T) m.get(param)) + .getOrNull(); + } + // endregion +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponse.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponse.java new file mode 100644 index 00000000..1a58bd4b --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponse.java @@ -0,0 +1,82 @@ +package com.sap.ai.sdk.orchestration.spring; + +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult; +import com.sap.ai.sdk.orchestration.client.model.ModuleResults; +import com.sap.ai.sdk.orchestration.client.model.TokenUsage; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.Nonnull; +import lombok.EqualsAndHashCode; +import lombok.Value; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.DefaultUsage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; + +@Value +@EqualsAndHashCode(callSuper = true) +public class OrchestrationChatResponse extends ChatResponse { + @Nonnull ModuleResults moduleResults; + + private OrchestrationChatResponse( + @Nonnull final List generations, + @Nonnull final ChatResponseMetadata metadata, + @Nonnull final ModuleResults moduleResults) { + super(generations, metadata); + this.moduleResults = moduleResults; + } + + @Nonnull + static OrchestrationChatResponse fromOrchestrationResponse( + @Nonnull final CompletionPostResponse response) { + final var generations = toGenerations(response.getOrchestrationResult()); + + final var metadata = toChatResponseMetadata(response.getOrchestrationResult()); + return new OrchestrationChatResponse(generations, metadata, response.getModuleResults()); + } + + @Nonnull + static List toGenerations(@Nonnull final LLMModuleResult result) { + return result.getChoices().stream() + .map(OrchestrationChatResponse::toAssistantMessage) + .map(Generation::new) + .toList(); + } + + @Nonnull + static AssistantMessage toAssistantMessage(@Nonnull final LLMChoice choice) { + Map metadata = new HashMap<>(); + metadata.put("finish_reason", choice.getFinishReason()); + metadata.put("index", choice.getIndex()); + if (!choice.getLogprobs().isEmpty()) { + metadata.put("logprobs", choice.getLogprobs()); + } + return new AssistantMessage(choice.getMessage().getContent(), metadata); + } + + @Nonnull + static ChatResponseMetadata toChatResponseMetadata( + @Nonnull final LLMModuleResult orchestrationResult) { + var metadataBuilder = ChatResponseMetadata.builder(); + + metadataBuilder.withId(orchestrationResult.getId()); + metadataBuilder.withModel(orchestrationResult.getModel()); + metadataBuilder.withKeyValue("object", orchestrationResult.getObject()); + metadataBuilder.withKeyValue("created", orchestrationResult.getCreated()); + metadataBuilder.withUsage(toDefaultUsage(orchestrationResult.getUsage())); + + return metadataBuilder.build(); + } + + @Nonnull + private static DefaultUsage toDefaultUsage(@Nonnull final TokenUsage usage) { + return new DefaultUsage( + usage.getPromptTokens().longValue(), + usage.getCompletionTokens().longValue(), + usage.getTotalTokens().longValue()); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringProperties.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringProperties.java new file mode 100644 index 00000000..278a1b88 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/spring/OrchestrationSpringProperties.java @@ -0,0 +1,11 @@ +package com.sap.ai.sdk.orchestration.spring; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +// TODO can this be a static inner class? +@ConfigurationProperties(prefix = OrchestrationSpringProperties.CONFIG_PREFIX) +public record OrchestrationSpringProperties(Llm llm) { + public static final String CONFIG_PREFIX = "com.sap.ai.sdk.orchestration"; + + public record Llm(String modelName, String modelVersion) {} +} diff --git a/orchestration/src/main/resources/META-INF.spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/orchestration/src/main/resources/META-INF.spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000..920fda0c --- /dev/null +++ b/orchestration/src/main/resources/META-INF.spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +com.sap.ai.genai.orchestration.OrchestrationAutoConfiguration \ No newline at end of file diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfigTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfigTest.java new file mode 100644 index 00000000..88e1eae9 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/DefaultOrchestrationConfigTest.java @@ -0,0 +1,60 @@ +package com.sap.ai.sdk.orchestration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import java.util.Map; + +import io.vavr.control.Option; +import org.junit.jupiter.api.Test; + +class DefaultOrchestrationConfigTest { + private static final OrchestrationConfig DEFAULT_CONFIG = + DefaultOrchestrationConfig.standalone() + .withLlmConfig(mock(LLMModuleConfig.class)) + .withMaskingConfig(mock(MaskingConfig.class)); + + @Test + void testStandalone() { + var config = DefaultOrchestrationConfig.standalone(); + + assertThat(config.withMaskingConfig(null)).isSameAs(config); + } + + @Test + @SuppressWarnings("unchecked") + void testDelegation() { + var mock = mock(OrchestrationConfig.class); + + var config = DefaultOrchestrationConfig.asDelegateFor(mock); + assertThat(config.withMaskingConfig(null)).isSameAs(mock); + } + + @Test + void testCopy() { + var config = DefaultOrchestrationConfig.standalone(); + var duplicate = config.copyFrom(DEFAULT_CONFIG); + + assertThat(duplicate) + .isEqualTo(DEFAULT_CONFIG) + .hasSameHashCodeAs(DEFAULT_CONFIG) + .isSameAs(config) + .isNotSameAs(DEFAULT_CONFIG); + } + + @Test + void testApplyingDefaults() { + var config = DefaultOrchestrationConfig.standalone(); + var llm = LLMModuleConfig.create().modelName("foo").modelParams(Map.of()); + config.withLlmConfig(llm); + + config.copyFrom(DEFAULT_CONFIG); + + assertThat(config) + .isNotEqualTo(DEFAULT_CONFIG) + .extracting(OrchestrationConfig::getLlmConfig) + .extracting(Option::get) + .isEqualTo(llm); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java new file mode 100644 index 00000000..49634d20 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/ModuleConfigFactoryTest.java @@ -0,0 +1,132 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Sensitivity.HIGH; +import static com.sap.ai.sdk.orchestration.ModuleConfigFactory.toModuleConfigDTO; +import static com.sap.ai.sdk.orchestration.client.model.FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.DPIEntities; +import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig; +import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class ModuleConfigFactoryTest { + private static final List messages = List.of(mock(ChatMessage.class)); + private DefaultOrchestrationConfig config; + + @BeforeEach + void setUp() { + config = DefaultOrchestrationConfig.standalone(); + config.withLlmConfig(mock(LLMModuleConfig.class)); + } + + @Test + void testThrowsOnMissingConfig() { + config = DefaultOrchestrationConfig.standalone(); + + assertThatThrownBy(() -> toModuleConfigDTO(config, messages)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("LLM module config is required"); + + config.withLlmConfig(mock(LLMModuleConfig.class)); + assertThatThrownBy(() -> toModuleConfigDTO(config, List.of())) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("prompt is required"); + } + + @Test + void testLlmConfig() { + var llmConfig = mock(LLMModuleConfig.class); + + config.withLlmConfig(llmConfig); + + var result = toModuleConfigDTO(config, messages).getLlmModuleConfig(); + assertThat(result).isSameAs(llmConfig); + } + + @Test + void testTemplateIsCreatedFromMessages() { + var result = toModuleConfigDTO(config, messages).getTemplatingModuleConfig(); + + assertThat(result.getTemplate()).containsExactly(messages.get(0)); + assertThat(result.getDefaults()).isNull(); + } + + @Test + void testMessagesAreMergedIntoTemplate() { + var message1 = mock(ChatMessage.class); + var message2 = mock(ChatMessage.class); + config.withTemplate(TemplatingModuleConfig.create().template(List.of(message1))); + + var result = toModuleConfigDTO(config, List.of(message2)).getTemplatingModuleConfig(); + + assertThat(result.getTemplate()).containsExactly(message2, message1); + } + + @Test + void testInputFilter() { + var filter = new AzureContentFilter().hate(HIGH); + config.withInputContentFilter(filter); + + var result = toModuleConfigDTO(config, messages).getFilteringModuleConfig(); + + assertThat(result.getInput().getFilters()).isNotEmpty(); + + var filterDto = result.getInput().getFilters().get(0); + assertThat(filterDto.getType()).isEqualTo(AZURE_CONTENT_SAFETY); + assertThat(filterDto.getConfig().getHate().getValue()).isZero(); + assertThat(filterDto.getConfig().getViolence()).isNull(); + + assertThat(result.getOutput()).isNull(); + } + + @Test + void testOutputFilter() { + var filter = new AzureContentFilter().hate(HIGH); + config.withOutputContentFilter(filter); + + var result = toModuleConfigDTO(config, messages).getFilteringModuleConfig(); + + assertThat(result.getOutput().getFilters()).isNotEmpty(); + var filterDto = result.getOutput().getFilters().get(0); + assertThat(filterDto.getType()).isEqualTo(AZURE_CONTENT_SAFETY); + assertThat(filterDto.getConfig().getHate().getValue()).isZero(); + assertThat(filterDto.getConfig().getViolence()).isNull(); + assertThat(result.getInput()).isNull(); + } + + @Test + void testInputAndOutputFilter() { + var inputFilter = new AzureContentFilter(); + var outputFilter = new AzureContentFilter(); + config.withInputContentFilter(inputFilter); + config.withOutputContentFilter(outputFilter); + + var result = toModuleConfigDTO(config, messages).getFilteringModuleConfig(); + + assertThat(result.getInput().getFilters()).isNotEmpty(); + assertThat(result.getOutput().getFilters()).isNotEmpty(); + } + + @Test + void testMasking() { + var maskingConfig = DpiMaskingConfig.anonymization().withEntities(DPIEntities.ADDRESS); + config.withMaskingConfig(maskingConfig); + + var result = toModuleConfigDTO(config, messages).getMaskingModuleConfig(); + + assertThat(result.getMaskingProviders()) + .isNotEmpty() + .extracting(MaskingProviderConfig::getEntities) + .extracting(it -> it.get(0)) + .extracting(DPIEntityConfig::getType) + .containsOnly(DPIEntities.ADDRESS); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationClientTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationClientTest.java new file mode 100644 index 00000000..c96221ea --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationClientTest.java @@ -0,0 +1,78 @@ +package com.sap.ai.sdk.orchestration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; + +/** + * Test that queries are on the right URL, with the right headers. Also check that the received + * response is parsed correctly in the generated client. + */ +public class OrchestrationClientTest { + private OrchestrationClient client; + + private static final LLMModuleConfig LLM_CONFIG = + LLMModuleConfig.create().modelName("gpt-35-turbo-16k").modelParams(Map.of()); + + @BeforeEach + void setup() { + var destination = DefaultHttpDestination.builder("").build(); + client = spy(new OrchestrationClient(destination).withLlmConfig(LLM_CONFIG)); + } + + @Test + void testSimpleChatCompletion() { + stubResponse("stop"); + + var result = client.chatCompletion("Hello there!"); + assertThat(result).isEqualTo("General Kenobi!"); + + var expected = ChatMessage.create().role("user").content("Hello there!"); + verify(client) + .chatCompletion( + ArgumentMatchers.argThat( + prompt -> prompt.getMessages().contains(expected))); + } + + @Test + void testSimpleChatCompletionThrowsOnOutputContentFilter() { + + stubResponse("content_filter"); + + assertThatThrownBy(() -> client.chatCompletion("foo")) + .isInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("content filter"); + } + + private void stubResponse(String finishReason) { + var response = mock(CompletionPostResponse.class); + var orchestrationResult = mock(LLMModuleResult.class); + var llmChoice = + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("General Kenobi!")) + .finishReason(finishReason); + + when(orchestrationResult.getChoices()).thenReturn(List.of(llmChoice)); + when(response.getOrchestrationResult()).thenReturn(orchestrationResult); + + doReturn(response).when(client).executeRequest(any()); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandlerTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandlerTest.java new file mode 100644 index 00000000..f77b26de --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationResponseHandlerTest.java @@ -0,0 +1,82 @@ +package com.sap.ai.sdk.orchestration; + +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.badRequest; +import static com.github.tomakehurst.wiremock.client.WireMock.ok; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.serverError; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Test that queries are on the right URL, with the right headers. Also check that the received + * response is parsed correctly in the generated client. + */ +@SuppressWarnings("UnstableApiUsage") +@WireMockTest +public class OrchestrationResponseHandlerTest { + private OrchestrationClient client; + + private static final LLMModuleConfig LLM_CONFIG = + LLMModuleConfig.create().modelName("gpt-35-turbo-16k").modelParams(Map.of()); + + @BeforeEach + void setup(WireMockRuntimeInfo server) { + var destination = DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); + client = new OrchestrationClient(destination).withLlmConfig(LLM_CONFIG); + ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); + } + + @AfterEach + void reset() { + ApacheHttpClient5Accessor.setHttpClientCache(null); + ApacheHttpClient5Accessor.setHttpClientFactory(null); + } + + @Test + void testSuccessResponse() { + var response = + ok().withBodyFile("serializedResponse.json").withHeader("Content-Type", "application/json"); + stubFor(post(anyUrl()).willReturn(response)); + + var result = client.chatCompletion("Hello there!"); + + assertThat(result).isEqualTo("General Kenobi!"); + } + + @Test + void testGenericErrorHandling() { + stubFor(post(anyUrl()).willReturn(serverError())); + + assertThatThrownBy(() -> client.chatCompletion("Hello World!")) + .isInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("500 Server Error"); + } + + @Test + void testOrchestrationErrorParsing() { + stubFor( + post(anyUrl()) + .willReturn( + badRequest() + .withHeader("Content-Type", "application/json") + .withBodyFile("errorResponse.json"))); + + assertThatThrownBy(() -> client.chatCompletion("Hello World!")) + .isInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("400 Bad Request") + .hasMessageContaining("'orchestration_config' is a required property"); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/SerializationTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/SerializationTest.java new file mode 100644 index 00000000..6474fd29 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/SerializationTest.java @@ -0,0 +1,139 @@ +package com.sap.ai.sdk.orchestration; + +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Sensitivity.LENIENT; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Sensitivity.LOW; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Sensitivity.MEDIUM; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Sensitivity.HIGH; +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.databind.JsonNode; +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; +import com.sap.ai.sdk.orchestration.client.model.DPIEntities; +import com.sap.ai.sdk.orchestration.client.model.GenericModuleResult; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult; +import com.sap.ai.sdk.orchestration.client.model.ModuleResults; +import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.TokenUsage; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; + +class SerializationTest { + + @Test + void testSerialization() throws IOException { + var llm = + LLMModuleConfig.create() + .modelName("gpt-35-turbo-16k") + .modelParams(Map.of("temperature", 0.5, "frequency_penalty", 1)); + var template = + TemplatingModuleConfig.create() + .template(List.of(ChatMessage.create().role("user").content("{{?input}}"))) + .defaults(Map.of("input", "Hello World!")); + + var inputFilter = new AzureContentFilter().selfHarm(LOW); + var outputFilter = + new AzureContentFilter() + .hate(HIGH) + .selfHarm(MEDIUM) + .sexual(LOW); + + var masking = + DpiMaskingConfig.anonymization() + .withEntities(DPIEntities.ADDRESS, DPIEntities.IBAN, DPIEntities.LOCATION); + var inputParams = Map.of("input", "Reply with 'Orchestration Service is working!' in German"); + + var dto = + new OrchestrationPrompt(inputParams) + .withLlmConfig(llm) + .withTemplate(template) + .withInputContentFilter(inputFilter) + .withOutputContentFilter(outputFilter) + .withMaskingConfig(masking) + .toCompletionPostRequestDTO(DefaultOrchestrationConfig.standalone()); + + var actual = OrchestrationClient.JACKSON.valueToTree(dto); + + var expected = + OrchestrationClient.JACKSON.readValue( + getClass().getClassLoader().getResource("serializedRequest.json"), JsonNode.class); + + assertThat(actual) + .withRepresentation(it -> ((JsonNode) it).toPrettyString()) + .isEqualTo(expected); + } + + @Test + void testDeserialization() throws IOException { + + var llmResult = + LLMModuleResult.create() + .id("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj") + ._object("chat.completion") + .created(1721224505) + .model("gpt-35-turbo-16k") + .choices( + List.of( + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("General Kenobi!")) + .finishReason("stop"))) + .usage(TokenUsage.create().completionTokens(7).promptTokens(19).totalTokens(26)); + + var orchestrationResult = + LLMModuleResult.create() + .id("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj") + ._object("chat.completion") + .created(1721224505) + .model("gpt-35-turbo-16k") + .choices( + List.of( + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("General Kenobi!")) + .finishReason("stop"))) + .usage(TokenUsage.create().completionTokens(7).promptTokens(19).totalTokens(26)); + + var inputFilterResult = + GenericModuleResult.create() + .message("Input filter passed successfully.") + .data( + Map.of( + "original_service_response", + Map.of("Hate", 0, "SelfHarm", 0, "Sexual", 0, "Violence", 2), + "checked_text", + "Hello there!")); + var outputFilterResult = + GenericModuleResult.create() + .message("Output filter passed successfully.") + .data( + Map.of( + "original_service_response", + Map.of("Hate", 0, "SelfHarm", 0, "Sexual", 0, "Violence", 2), + "checked_text", + "General Kenobi!")); + var expected = + CompletionPostResponse.create() + .requestId("26ea36b5-c196-4806-a9a6-a686f0c6ad91") + .moduleResults( + ModuleResults.create() + .templating(List.of(ChatMessage.create().role("user").content("Hello there!"))) + .llm(llmResult) + .inputFiltering(inputFilterResult) + .outputFiltering(outputFilterResult) + .inputMasking(null) + .outputUnmasking(List.of())) + .orchestrationResult(orchestrationResult); + + var actual = + OrchestrationClient.JACKSON.readValue( + getClass().getClassLoader().getResourceAsStream("serializedResponse.json"), + CompletionPostResponse.class); + + assertThat(actual).isEqualTo(expected); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java deleted file mode 100644 index e9a2f79b..00000000 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/client/OrchestrationUnitTest.java +++ /dev/null @@ -1,291 +0,0 @@ -package com.sap.ai.sdk.orchestration.client; - -import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; -import static com.github.tomakehurst.wiremock.client.WireMock.jsonResponse; -import static com.github.tomakehurst.wiremock.client.WireMock.okJson; -import static com.github.tomakehurst.wiremock.client.WireMock.post; -import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; -import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; -import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; -import static com.github.tomakehurst.wiremock.client.WireMock.verify; -import static com.sap.ai.sdk.core.Core.getClient; -import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_0; -import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.NUMBER_4; -import static org.apache.hc.core5.http.HttpStatus.SC_BAD_REQUEST; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; -import com.github.tomakehurst.wiremock.junit5.WireMockTest; -import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; -import com.sap.ai.sdk.orchestration.client.model.AzureThreshold; -import com.sap.ai.sdk.orchestration.client.model.ChatMessage; -import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; -import com.sap.ai.sdk.orchestration.client.model.FilterConfig; -import com.sap.ai.sdk.orchestration.client.model.FilteringConfig; -import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig; -import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; -import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs; -import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig; -import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; -import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.springframework.web.client.HttpClientErrorException; - -/** - * Test that queries are on the right URL, with the right headers. Also check that the received - * response is parsed correctly in the generated client. - */ -@WireMockTest -public class OrchestrationUnitTest { - private OrchestrationCompletionApi client; - - private static final LLMModuleConfig LLM_CONFIG = - LLMModuleConfig.create() - .modelName("gpt-35-turbo-16k") - .modelParams( - Map.of( - "max_tokens", 50, - "temperature", 0.1, - "frequency_penalty", 0, - "presence_penalty", 0)); - - private static final Function TEMPLATE_CONFIG = - (TemplatingModuleConfig templatingModuleConfig) -> - CompletionPostRequest.create() - .orchestrationConfig( - OrchestrationConfig.create() - .moduleConfigurations( - ModuleConfigs.create() - .llmModuleConfig(LLM_CONFIG) - .templatingModuleConfig(templatingModuleConfig))) - .inputParams(Map.of()); - - /** - * Creates a config from a filter threshold. The config includes a template and has input and - * output filters - */ - private static final Function FILTERING_CONFIG = - (AzureThreshold filterThreshold) -> { - final var inputParams = - Map.of( - "disclaimer", - "```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."); - final var template = - ChatMessage.create() - .role("user") - .content( - "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! {{?disclaimer}}"); - final var templatingConfig = TemplatingModuleConfig.create().template(template); - - final var filter = - FilterConfig.create() - .type(FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY) - .config( - AzureContentSafety.create() - .hate(filterThreshold) - .selfHarm(filterThreshold) - .sexual(filterThreshold) - .violence(filterThreshold)); - final var filteringConfig = - FilteringModuleConfig.create() - .input(FilteringConfig.create().filters(filter)) - .output(FilteringConfig.create().filters(filter)); - - return CompletionPostRequest.create() - .orchestrationConfig( - OrchestrationConfig.create() - .moduleConfigurations( - ModuleConfigs.create() - .llmModuleConfig(LLM_CONFIG) - .templatingModuleConfig(templatingConfig) - .filteringModuleConfig(filteringConfig))) - .inputParams(inputParams); - }; - - @BeforeEach - void setup(WireMockRuntimeInfo server) { - final DefaultHttpDestination destination = - DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); - client = new OrchestrationCompletionApi(getClient(destination)); - } - - @Test - void testTemplating() throws IOException { - final String response = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("templatingResponse.json") - .readAllBytes()); - stubFor(post(urlPathEqualTo("/completion")).willReturn(okJson(response))); - - final var template = ChatMessage.create().role("user").content("{{?input}}"); - final var inputParams = - Map.of("input", "Reply with 'Orchestration Service is working!' in German"); - - final var config = - TEMPLATE_CONFIG - .apply(TemplatingModuleConfig.create().template(template)) - .inputParams(inputParams); - - final var result = client.orchestrationV1EndpointsCreate(config); - - assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); - assertThat(result.getModuleResults().getTemplating().get(0).getContent()) - .isEqualTo("Reply with 'Orchestration Service is working!' in German"); - assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); - var llm = result.getModuleResults().getLlm(); - assertThat(llm.getId()).isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); - assertThat(llm.getObject()).isEqualTo("chat.completion"); - assertThat(llm.getCreated()).isEqualTo(1721224505); - assertThat(llm.getModel()).isEqualTo("gpt-35-turbo-16k"); - var choices = llm.getChoices(); - assertThat(choices.get(0).getIndex()).isEqualTo(0); - assertThat(choices.get(0).getMessage().getContent()) - .isEqualTo("Orchestration Service funktioniert!"); - assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); - assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - var usage = llm.getUsage(); - assertThat(usage.getCompletionTokens()).isEqualTo(7); - assertThat(usage.getPromptTokens()).isEqualTo(19); - assertThat(usage.getTotalTokens()).isEqualTo(26); - assertThat(result.getOrchestrationResult().getId()) - .isEqualTo("chatcmpl-9lzPV4kLrXjFckOp2yY454wksWBoj"); - assertThat(result.getOrchestrationResult().getObject()).isEqualTo("chat.completion"); - assertThat(result.getOrchestrationResult().getCreated()).isEqualTo(1721224505); - assertThat(result.getOrchestrationResult().getModel()).isEqualTo("gpt-35-turbo-16k"); - choices = result.getOrchestrationResult().getChoices(); - assertThat(choices.get(0).getIndex()).isEqualTo(0); - assertThat(choices.get(0).getMessage().getContent()) - .isEqualTo("Orchestration Service funktioniert!"); - assertThat(choices.get(0).getMessage().getRole()).isEqualTo("assistant"); - assertThat(choices.get(0).getFinishReason()).isEqualTo("stop"); - usage = result.getOrchestrationResult().getUsage(); - assertThat(usage.getCompletionTokens()).isEqualTo(7); - assertThat(usage.getPromptTokens()).isEqualTo(19); - assertThat(usage.getTotalTokens()).isEqualTo(26); - - // verify that null fields are absent from the sent request - final String request = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("templatingRequest.json") - .readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); - } - - @Test - void testTemplatingBadRequest() { - stubFor( - post(urlPathEqualTo("/completion")) - .willReturn( - jsonResponse( - """ - { - "request_id": "51043a32-01f5-429a-b0e7-3a99432e43a4", - "code": 400, - "message": "Missing required parameters: ['input']", - "location": "Module: Templating", - "module_results": {} - } - """, - SC_BAD_REQUEST))); - - final var template = ChatMessage.create().role("user").content("{{?input}}"); - // input params are omitted on purpose to trigger an error - Map inputParams = Map.of(); - - final var config = - TEMPLATE_CONFIG - .apply(TemplatingModuleConfig.create().template(template)) - .inputParams(inputParams); - - assertThatThrownBy(() -> client.orchestrationV1EndpointsCreate(config)) - .isInstanceOf(HttpClientErrorException.class) - .hasMessage( - "400 Bad Request: \"{ \"request_id\": \"51043a32-01f5-429a-b0e7-3a99432e43a4\", \"code\": 400, \"message\": \"Missing required parameters: ['input']\", \"location\": \"Module: Templating\", \"module_results\": {}}\""); - } - - @Test - void testFilteringLoose() throws IOException { - final String response = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("filteringLooseResponse.json") - .readAllBytes()); - stubFor(post(urlPathEqualTo("/completion")).willReturn(okJson(response))); - - final var config = FILTERING_CONFIG.apply(NUMBER_4); - - client.orchestrationV1EndpointsCreate(config); - // the result is asserted in the verify step below - - // verify that null fields are absent from the sent request - final String request = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("filteringLooseRequest.json") - .readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); - } - - @Test - void testFilteringStrict() { - final String response = - """ - {"request_id": "bf6d6792-7adf-4d3c-9368-a73615af8c5a", "code": 400, "message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "location": "Input Filter", "module_results": {"templating": [{"role": "user", "content": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}], "input_filtering": {"message": "Content filtered due to Safety violations. Please modify the prompt and try again.", "data": {"original_service_response": {"Hate": 0, "SelfHarm": 0, "Sexual": 0, "Violence": 2}, "checked_text": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."}}}}"""; - stubFor(post(urlPathEqualTo("/completion")).willReturn(jsonResponse(response, SC_BAD_REQUEST))); - - final var config = FILTERING_CONFIG.apply(NUMBER_0); - - assertThatThrownBy(() -> client.orchestrationV1EndpointsCreate(config)) - .isInstanceOf(HttpClientErrorException.class) - .hasMessage("400 Bad Request: \"" + response + "\""); - } - - @Test - void testMessagesHistory() throws IOException { - final String response = - new String( - getClass() - .getClassLoader() - // the response is not asserted in this test - .getResourceAsStream("templatingResponse.json") - .readAllBytes()); - stubFor(post(urlPathEqualTo("/completion")).willReturn(okJson(response))); - - final List messagesHistory = - List.of( - ChatMessage.create().role("user").content("What is the capital of France?"), - ChatMessage.create().role("assistant").content("The capital of France is Paris.")); - final var message = - ChatMessage.create().role("user").content("What is the typical food there?"); - - final var config = - TEMPLATE_CONFIG - .apply(TemplatingModuleConfig.create().template(message)) - .messagesHistory(messagesHistory); - - final var result = client.orchestrationV1EndpointsCreate(config); - - assertThat(result.getRequestId()).isEqualTo("26ea36b5-c196-4806-a9a6-a686f0c6ad91"); - - // verify that the history is sent correctly - final String request = - new String( - getClass() - .getClassLoader() - .getResourceAsStream("messagesHistoryRequest.json") - .readAllBytes()); - verify(postRequestedFor(urlPathEqualTo("/completion")).withRequestBody(equalToJson(request))); - } -} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptionsTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptionsTest.java new file mode 100644 index 00000000..a06b8e7d --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatOptionsTest.java @@ -0,0 +1,55 @@ +package com.sap.ai.sdk.orchestration.spring; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class OrchestrationChatOptionsTest { + + OrchestrationChatOptions opts; + + @BeforeEach + void setUp() { + opts = new OrchestrationChatOptions(); + } + + @Test + void testFluentApi() { + assertThat(opts).isSameAs(opts.withLlmConfig(null)); + } + + @Test + void testHyperParameters() { + var llm = + LLMModuleConfig.create() + .modelName("foo") + .modelParams(Map.of("temperature", 0.5, "maxTokens", 100)); + opts.withLlmConfig(llm); + + assertThat(opts.getTemperature()).isEqualTo(0.5); + assertThat(opts.getMaxTokens()).isEqualTo(100); + } + + @Test + void testEqualsAndHashCode() { + var llm = + LLMModuleConfig.create() + .modelName("foo") + .modelParams(Map.of("temperature", 0.5, "maxTokens", 100)); + + var opts1 = + new OrchestrationChatOptions() + .withTemplateParameters(Map.of("foo", "bar")) + .withLlmConfig(llm); + var opts2 = + new OrchestrationChatOptions() + .withTemplateParameters(Map.of("foo", "bar")) + .withLlmConfig(llm); + + assertThat(opts1).isEqualTo(opts2); + assertThat(opts1.hashCode()).isEqualTo(opts2.hashCode()); + } +} diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java new file mode 100644 index 00000000..5a188ca5 --- /dev/null +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/spring/OrchestrationChatResponseTest.java @@ -0,0 +1,54 @@ +package com.sap.ai.sdk.orchestration.spring; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.sap.ai.sdk.orchestration.client.model.ChatMessage; +import com.sap.ai.sdk.orchestration.client.model.LLMChoice; +import com.sap.ai.sdk.orchestration.client.model.LLMModuleResult; +import com.sap.ai.sdk.orchestration.client.model.TokenUsage; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; + +class OrchestrationChatResponseTest { + + @Test + void testToAssistantMessage() { + var choice = + LLMChoice.create() + .index(0) + .message(ChatMessage.create().role("assistant").content("Hello, world!")) + .finishReason("stop"); + + AssistantMessage message = OrchestrationChatResponse.toAssistantMessage(choice); + + assertThat(message.getContent()).isEqualTo("Hello, world!"); + assertThat(message.getMetadata()).containsEntry("finish_reason", "stop"); + assertThat(message.getMetadata()).containsEntry("index", 0); + } + + @Test + void testToChatResponseMetadata() { + var moduleResult = + LLMModuleResult.create() + .id("test-id") + ._object("test-object") + .created(123456789) + .model("test-model") + .choices(List.of()) + .usage(TokenUsage.create().completionTokens(20).promptTokens(10).totalTokens(30)); + + var metadata = OrchestrationChatResponse.toChatResponseMetadata(moduleResult); + + assertThat(metadata.getId()).isEqualTo("test-id"); + assertThat(metadata.getModel()).isEqualTo("test-model"); + assertThat(metadata.get("object")).isEqualTo("test-object"); + assertThat(metadata.get("created")).isEqualTo(123456789); + + var usage = metadata.getUsage(); + + assertThat(usage.getPromptTokens()).isEqualTo(10L); + assertThat(usage.getGenerationTokens()).isEqualTo(20L); + assertThat(usage.getTotalTokens()).isEqualTo(30L); + } +} diff --git a/orchestration/src/test/resources/errorResponse.json b/orchestration/src/test/resources/errorResponse.json new file mode 100644 index 00000000..6964cbc1 --- /dev/null +++ b/orchestration/src/test/resources/errorResponse.json @@ -0,0 +1,7 @@ +{ + "request_id": "59468e72-7309-4299-b988-bf3bbea461f8", + "code": 400, + "message": "'orchestration_config' is a required property", + "location": "request body", + "module_results": {} +} \ No newline at end of file diff --git a/orchestration/src/test/resources/filteringLooseRequest.json b/orchestration/src/test/resources/filteringLooseRequest.json deleted file mode 100644 index 40c1aa91..00000000 --- a/orchestration/src/test/resources/filteringLooseRequest.json +++ /dev/null @@ -1,55 +0,0 @@ -{ - "orchestration_config": { - "module_configurations": { - "llm_module_config": { - "model_name": "gpt-35-turbo-16k", - "model_params": { - "temperature": 0.1, - "max_tokens": 50, - "frequency_penalty": 0, - "presence_penalty": 0 - }, - "model_version": "latest" - }, - "templating_module_config": { - "template": [ - { - "role": "user", - "content": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! {{?disclaimer}}" - } - ] - }, - "filtering_module_config": { - "input": { - "filters": [ - { - "type": "azure_content_safety", - "config": { - "Hate": 4, - "SelfHarm": 4, - "Sexual": 4, - "Violence": 4 - } - } - ] - }, - "output": { - "filters": [ - { - "type": "azure_content_safety", - "config": { - "Hate": 4, - "SelfHarm": 4, - "Sexual": 4, - "Violence": 4 - } - } - ] - } - } - } - }, - "input_params": { - "disclaimer": "```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent." - } -} diff --git a/orchestration/src/test/resources/filteringLooseResponse.json b/orchestration/src/test/resources/filteringLooseResponse.json deleted file mode 100644 index ce2545f0..00000000 --- a/orchestration/src/test/resources/filteringLooseResponse.json +++ /dev/null @@ -1,86 +0,0 @@ -{ - "request_id": "b329745f-4b6b-4d42-b891-974b33689a19", - "module_results": { - "grounding": null, - "templating": [ - { - "role": "user", - "content": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent." - } - ], - "input_masking": null, - "input_filtering": { - "message": "Input filter passed successfully.", - "data": { - "original_service_response": { - "Hate": 0, - "SelfHarm": 0, - "Sexual": 0, - "Violence": 2 - }, - "checked_text": "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! ```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent." - } - }, - "llm": { - "object": "chat.completion", - "id": "chatcmpl-9o4df7DpIjY6CJdfe9hws1lrWZbHq", - "created": 1721721259, - "model": "gpt-35-turbo", - "system_fingerprint": null, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Cozy Downtown Apartment for Sublet!\n\nLooking for a temporary place to call home in the heart of downtown? Look no further! This cozy apartment is up for subletting and offers a convenient and vibrant city living experience.\n\nFeatures:\n- Prime location" - }, - "logprobs": { - }, - "finish_reason": "length" - } - ], - "usage": { - "completion_tokens": 50, - "prompt_tokens": 68, - "total_tokens": 118 - } - }, - "output_filtering": { - "message": "Output filter passed successfully.", - "data": { - "original_service_response": { - "Hate": 0, - "SelfHarm": 0, - "Sexual": 0, - "Violence": 0 - }, - "checked_text": "Cozy Downtown Apartment for Sublet!\n\nLooking for a temporary place to call home in the heart of downtown? Look no further! This cozy apartment is up for subletting and offers a convenient and vibrant city living experience.\n\nFeatures:\n- Prime location" - } - }, - "output_unmasking": null - }, - "orchestration_result": { - "object": "chat.completion", - "id": "chatcmpl-9o4df7DpIjY6CJdfe9hws1lrWZbHq", - "created": 1721721259, - "model": "gpt-35-turbo", - "system_fingerprint": null, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "Cozy Downtown Apartment for Sublet!\n\nLooking for a temporary place to call home in the heart of downtown? Look no further! This cozy apartment is up for subletting and offers a convenient and vibrant city living experience.\n\nFeatures:\n- Prime location" - }, - "logprobs": { - }, - "finish_reason": "length" - } - ], - "usage": { - "completion_tokens": 50, - "prompt_tokens": 68, - "total_tokens": 118 - } - } -} diff --git a/orchestration/src/test/resources/messagesHistoryRequest.json b/orchestration/src/test/resources/messagesHistoryRequest.json deleted file mode 100644 index 984e8d08..00000000 --- a/orchestration/src/test/resources/messagesHistoryRequest.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - "orchestration_config": { - "module_configurations": { - "llm_module_config": { - "model_name": "gpt-35-turbo-16k", - "model_params": { - "presence_penalty": 0, - "frequency_penalty": 0, - "max_tokens": 50, - "temperature": 0.1 - }, - "model_version": "latest" - }, - "templating_module_config": { - "template": [ - { - "role": "user", - "content": "What is the typical food there?" - } - ] - } - } - }, - "input_params": {}, - "messages_history": [ - { - "role": "user", - "content": "What is the capital of France?" - }, - { - "role": "assistant", - "content": "The capital of France is Paris." - } - ] -} diff --git a/orchestration/src/test/resources/serializedRequest.json b/orchestration/src/test/resources/serializedRequest.json new file mode 100644 index 00000000..a4101d65 --- /dev/null +++ b/orchestration/src/test/resources/serializedRequest.json @@ -0,0 +1,60 @@ +{ + "orchestration_config" : { + "module_configurations" : { + "llm_module_config" : { + "model_name" : "gpt-35-turbo-16k", + "model_params" : { + "frequency_penalty" : 1, + "temperature" : 0.5 + }, + "model_version" : "latest" + }, + "templating_module_config" : { + "template" : [ { + "role" : "user", + "content" : "{{?input}}" + } ], + "defaults" : { + "input" : "Hello World!" + } + }, + "filtering_module_config" : { + "input" : { + "filters" : [ { + "type" : "azure_content_safety", + "config" : { + "SelfHarm" : 6 + } + } ] + }, + "output" : { + "filters" : [ { + "type" : "azure_content_safety", + "config" : { + "Hate" : 0, + "SelfHarm" : 2, + "Sexual": 4, + "Violence": 6 + } + } ] + } + }, + "masking_module_config" : { + "masking_providers" : [ { + "type" : "sap_data_privacy_integration", + "method" : "anonymization", + "entities" : [ { + "type" : "profile-address" + }, { + "type" : "profile-iban" + }, { + "type" : "profile-location" + } ] + } ] + } + } + }, + "input_params" : { + "input" : "Reply with 'Orchestration Service is working!' in German" + } +} \ No newline at end of file diff --git a/orchestration/src/test/resources/templatingResponse.json b/orchestration/src/test/resources/serializedResponse.json similarity index 59% rename from orchestration/src/test/resources/templatingResponse.json rename to orchestration/src/test/resources/serializedResponse.json index 3fcc0a30..d471d2f5 100644 --- a/orchestration/src/test/resources/templatingResponse.json +++ b/orchestration/src/test/resources/serializedResponse.json @@ -4,7 +4,7 @@ "templating": [ { "role": "user", - "content": "Reply with 'Orchestration Service is working!' in German" + "content": "Hello there!" } ], "llm": { @@ -17,7 +17,7 @@ "index": 0, "message": { "role": "assistant", - "content": "Orchestration Service funktioniert!" + "content": "General Kenobi!" }, "finish_reason": "stop" } @@ -27,6 +27,30 @@ "prompt_tokens": 19, "total_tokens": 26 } + }, + "input_filtering": { + "message": "Input filter passed successfully.", + "data": { + "original_service_response": { + "Hate": 0, + "SelfHarm": 0, + "Sexual": 0, + "Violence": 2 + }, + "checked_text": "Hello there!" + } + }, + "output_filtering": { + "message": "Output filter passed successfully.", + "data": { + "original_service_response": { + "Hate": 0, + "SelfHarm": 0, + "Sexual": 0, + "Violence": 2 + }, + "checked_text": "General Kenobi!" + } } }, "orchestration_result": { @@ -39,7 +63,7 @@ "index": 0, "message": { "role": "assistant", - "content": "Orchestration Service funktioniert!" + "content": "General Kenobi!" }, "finish_reason": "stop" } diff --git a/orchestration/src/test/resources/templatingRequest.json b/orchestration/src/test/resources/templatingRequest.json deleted file mode 100644 index 3cde90a4..00000000 --- a/orchestration/src/test/resources/templatingRequest.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "orchestration_config": { - "module_configurations": { - "templating_module_config": { - "template": [ - { - "role": "user", - "content": "{{?input}}" - } - ] - }, - "llm_module_config": { - "model_name": "gpt-35-turbo-16k", - "model_params": { - "max_tokens": 50, - "temperature": 0.1, - "frequency_penalty": 0, - "presence_penalty": 0 - }, - "model_version": "latest" - } - } - }, - "input_params": { - "input": "Reply with 'Orchestration Service is working!' in German" - } -} diff --git a/pom.xml b/pom.xml index 04528a89..c33d8ca8 100644 --- a/pom.xml +++ b/pom.xml @@ -33,6 +33,7 @@ orchestration foundation-models/openai sample-code/spring-app + sample-code/spring-ai-app @@ -46,17 +47,21 @@ 17 17 UTF-8 + 5.12.0 + 6.1.13 + 3.3.3 + 1.0.0-SNAPSHOT + 2.0.16 + 5.11.0 3.9.1 3.26.3 - 2.0.16 2.43.0 10.18.1 2.1.3 3.5.0 2.1.3 - 6.1.13 5.13.0 false @@ -74,16 +79,30 @@ pom import - org.springframework - spring-core - ${springframework.version} + spring-framework-bom + ${spring.version} + pom + import - org.springframework - spring-web - ${springframework.version} + org.springframework.ai + spring-ai-bom + ${spring.ai.version} + pom + import + + + + org.springframework.boot + spring-boot + ${spring.boot.version} + + + org.springframework.boot + spring-boot-autoconfigure + ${spring.boot.version} @@ -146,6 +165,26 @@ test + + + + + true + + spring-milestones + Spring Milestones + https://repo.spring.io/milestone + + + + false + + spring-snapshots + Spring Snapshots + https://repo.spring.io/snapshot + + + @@ -165,7 +204,7 @@ - + @@ -269,7 +308,8 @@ - Do not use JUnit 4.X, instead exclusively use Junit Jupiter. JUnit 4 tests will (silently) simply not be executed. + Do not use JUnit 4.X, instead exclusively use Junit Jupiter. JUnit 4 tests + will (silently) simply not be executed. junit:junit @@ -387,7 +427,7 @@ .pipeline/pmd.xml .pipeline/sdk_specific_pmd_rules.xml + https://gitbox.apache.org/repos/asf?p=maven-pmd-plugin.git;a=blob_plain;f=src/main/resources/rulesets/java/maven-pmd-plugin-default.xml;hb=HEAD --> /rulesets/java/maven-pmd-plugin-default.xml false diff --git a/sample-code/spring-ai-app/README.md b/sample-code/spring-ai-app/README.md new file mode 100644 index 00000000..6704991d --- /dev/null +++ b/sample-code/spring-ai-app/README.md @@ -0,0 +1,20 @@ +# Sample Code and E2E Test + +![e2e-test](https://github.com/SAP/ai-sdk-java/actions/workflows/e2e-test.yaml/badge.svg) + +Sample code to demonstrate the usage of the SAP AI SDK. +Also used as basis for running E2E tests. + +## Build, Run, Deploy Locally + +Build the project with: + +1. `mvn compile` +2. [Download a service key for your AI Core service instance](../../README.md#set-credentials-as-dedicated-environment-variable) +3. Create the environment variable `AICORE_SERVICE_KEY` +4. Run the application with `mvn spring-boot:run` +5. [See all available endpoints](localhost:8080) + +## Run the E2E Test + +Trigger the [GitHub Action](https://github.com/SAP/ai-sdk-java/actions/workflows/e2e-test.yml). diff --git a/sample-code/spring-ai-app/pom.xml b/sample-code/spring-ai-app/pom.xml new file mode 100644 index 00000000..f57193a9 --- /dev/null +++ b/sample-code/spring-ai-app/pom.xml @@ -0,0 +1,140 @@ + + + 4.0.0 + + com.sap.ai.sdk + sdk-parent + 0.1.0-SNAPSHOT + ../../pom.xml + + com.sap.ai.sdk.app + spring-ai-app + Spring AI Sample Application + Spring Boot application that tests AI clients when deployed in BTP. + https://github.com/SAP/ai-sdk-java?tab=readme-ov-file#documentation + + SAP SE + https://www.sap.com + + + + The Apache Software License, Version 2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + + + + + SAP + cloudsdk@sap.com + SAP SE + https://www.sap.com + + + + + 1.5.8 + + true + + true + true + + + + + + com.sap.ai.sdk + orchestration + + + com.sap.cloud.sdk.cloudplatform + cloudplatform-core + + + org.springframework.ai + spring-ai-core + + + org.springframework.boot + spring-boot-starter-web + ${spring.boot.version} + + + org.springframework + spring-core + + + org.springframework + spring-web + + + org.springframework + spring-context + + + org.springframework + spring-beans + + + com.google.code.findbugs + jsr305 + + + org.slf4j + slf4j-api + + + + ch.qos.logback + logback-classic + ${logback.version} + runtime + + + + org.projectlombok + lombok + provided + + + + + + + ${project.artifactId} + + + org.springframework.boot + spring-boot-maven-plugin + 3.3.3 + + false + + + + + repackage + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + ${skipTests} + + + + + diff --git a/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/Application.java b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/Application.java new file mode 100644 index 00000000..e760bdcf --- /dev/null +++ b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/Application.java @@ -0,0 +1,20 @@ +package com.sap.ai.sdk.app; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.web.servlet.ServletComponentScan; +import org.springframework.context.annotation.ComponentScan; + +@SpringBootApplication +@ComponentScan({"com.sap.cloud.sdk", "com.sap.ai.sdk"}) +@ServletComponentScan({"com.sap.cloud.sdk", "com.sap.ai.sdk"}) +class Application { + /** + * Main method to start the Spring Boot application. + * + * @param args Command line arguments. + */ + public static void main(final String[] args) { + SpringApplication.run(Application.class, args); + } +} diff --git a/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/AsynchronousConfiguration.java b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/AsynchronousConfiguration.java new file mode 100644 index 00000000..5549ccb6 --- /dev/null +++ b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/AsynchronousConfiguration.java @@ -0,0 +1,26 @@ +package com.sap.ai.sdk.app; + +// Spring Boot asynchronous configuration + +import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors; +import java.util.concurrent.Executor; +import javax.annotation.Nonnull; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.annotation.Async; +import org.springframework.scheduling.annotation.AsyncConfigurer; +import org.springframework.scheduling.annotation.EnableAsync; + +/** + * Implementation that customizes the {@link Executor} instance used when processing @{@link Async} + * method invocations. + */ +@SuppressWarnings("unused") // Used by Spring +@EnableAsync +@Configuration +public class AsynchronousConfiguration implements AsyncConfigurer { + @Nonnull + @Override + public Executor getAsyncExecutor() { + return ThreadContextExecutors.getExecutor(); + } +} diff --git a/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/ControllerExceptionHandler.java b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/ControllerExceptionHandler.java new file mode 100644 index 00000000..a9263593 --- /dev/null +++ b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/ControllerExceptionHandler.java @@ -0,0 +1,30 @@ +package com.sap.ai.sdk.app; + +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.bind.annotation.ExceptionHandler; + +@ControllerAdvice +@SuppressWarnings("unused") // Used by Spring +class ControllerExceptionHandler { + /** Exceptions thrown by the Spring Boot controllers are turned into a readable text response. */ + @ExceptionHandler(Exception.class) + ResponseEntity handleError(final Exception ex) { + final var headers = new HttpHeaders(); + headers.setContentType(MediaType.TEXT_PLAIN); + + final var txt = + ex.getMessage() + + "\n\n" + + Stream.of(ex.getStackTrace()).map(Object::toString).collect(Collectors.joining("\n")) + + "\n\n" + + Stream.of(ex.getSuppressed()).map(Object::toString).collect(Collectors.joining("\n")); + + return new ResponseEntity<>(txt, headers, HttpStatus.INTERNAL_SERVER_ERROR); + } +} diff --git a/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java new file mode 100644 index 00000000..0f638d82 --- /dev/null +++ b/sample-code/spring-ai-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -0,0 +1,84 @@ +package com.sap.ai.sdk.app.controllers; + +import com.sap.ai.sdk.orchestration.DpiMaskingConfig; +import com.sap.ai.sdk.orchestration.OrchestrationClient; +import com.sap.ai.sdk.orchestration.client.model.DPIEntities; +import com.sap.ai.sdk.orchestration.spring.OrchestrationChatModel; +import com.sap.ai.sdk.orchestration.spring.OrchestrationChatOptions; +import java.util.List; +import java.util.Map; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +/** Endpoints for the Orchestration service */ +@RestController +@RequestMapping("/orchestration") +class OrchestrationController { + + // uses defaults from application.yaml + @Autowired OrchestrationChatModel client; + + OrchestrationChatModel clientWithoutDefaults = + new OrchestrationChatModel(new OrchestrationClient()); + + @GetMapping("/completion") + ChatResponse completion() { + var prompt = new Prompt("What is the capital of France?"); + + return client.call(prompt); + } + + @GetMapping("/template/local") + ChatResponse templateLocal() { + var template = new PromptTemplate("input"); + var prompt = template.create(Map.of("input", "Hello World!")); + + return client.call(prompt); + } + + @GetMapping("/template/remote") + ChatResponse templateRemote() { + List messages = List.of(new UserMessage("{{?input}}")); + var opts = + new OrchestrationChatOptions() + .withTemplate(messages) + .withTemplateParameters(Map.of("input", "Hello World!")); + var prompt = new Prompt(List.of(), opts); + + return client.call(prompt); + } + + @GetMapping("/masking") + ChatResponse masking() { + var masking = DpiMaskingConfig.anonymization().withEntities(DPIEntities.EMAIL, DPIEntities.ADDRESS, DPIEntities.LOCATION); + + var opts = new OrchestrationChatOptions().withMaskingConfig(masking); + var prompt = + new Prompt( + "Please write 'Hello World!' to me via email. My email address is foo.bar@baz.ai", + opts); + + return client.call(prompt); + } + + @GetMapping("/chatMemory") + ChatResponse chatMemory() { + var memory = new InMemoryChatMemory(); + var advisor = new MessageChatMemoryAdvisor(memory); + ChatClient cl = ChatClient.builder(client).defaultAdvisors(advisor).build(); + var prompt = new Prompt("Hello World!"); + + cl.prompt(prompt).call(); + return cl.prompt(prompt).call().chatResponse(); + } +} diff --git a/sample-code/spring-ai-app/src/main/resources/application.yaml b/sample-code/spring-ai-app/src/main/resources/application.yaml new file mode 100644 index 00000000..234ffe77 --- /dev/null +++ b/sample-code/spring-ai-app/src/main/resources/application.yaml @@ -0,0 +1,5 @@ +com.sap.ai.sdk: + orchestration: + llm: + modelName: gpt-35-turbo + modelVersion: latest \ No newline at end of file diff --git a/sample-code/spring-ai-app/src/main/resources/logback-spring.xml b/sample-code/spring-ai-app/src/main/resources/logback-spring.xml new file mode 100644 index 00000000..b1d41391 --- /dev/null +++ b/sample-code/spring-ai-app/src/main/resources/logback-spring.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/sample-code/spring-ai-app/src/main/resources/static/index.html b/sample-code/spring-ai-app/src/main/resources/static/index.html new file mode 100644 index 00000000..b7a2d904 --- /dev/null +++ b/sample-code/spring-ai-app/src/main/resources/static/index.html @@ -0,0 +1,85 @@ + + + AI SDK test app + + + +

Welcome to the end-to-end test application!

+ + + + + +

Endpoints

+ + + + diff --git a/sample-code/spring-app/pom.xml b/sample-code/spring-app/pom.xml index ce21ea15..6ce01712 100644 --- a/sample-code/spring-app/pom.xml +++ b/sample-code/spring-app/pom.xml @@ -84,7 +84,7 @@ org.springframework spring-context - ${springframework.version} + ${spring.version} org.springframework @@ -97,7 +97,7 @@ org.springframework spring-webmvc - ${springframework.version} + ${spring.version} com.google.code.findbugs diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java index 77f47531..8b96e689 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/OrchestrationController.java @@ -1,27 +1,18 @@ package com.sap.ai.sdk.app.controllers; -import static com.sap.ai.sdk.core.Core.getOrchestrationClient; - -import com.sap.ai.sdk.orchestration.client.OrchestrationCompletionApi; -import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety; -import com.sap.ai.sdk.orchestration.client.model.AzureThreshold; +import com.sap.ai.sdk.orchestration.AzureContentFilter; +import com.sap.ai.sdk.orchestration.AzureContentFilter.Sensitivity; +import com.sap.ai.sdk.orchestration.DpiMaskingConfig; +import com.sap.ai.sdk.orchestration.OrchestrationClient; +import com.sap.ai.sdk.orchestration.OrchestrationPrompt; import com.sap.ai.sdk.orchestration.client.model.ChatMessage; -import com.sap.ai.sdk.orchestration.client.model.CompletionPostRequest; import com.sap.ai.sdk.orchestration.client.model.CompletionPostResponse; -import com.sap.ai.sdk.orchestration.client.model.FilterConfig; -import com.sap.ai.sdk.orchestration.client.model.FilteringConfig; -import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig; +import com.sap.ai.sdk.orchestration.client.model.DPIEntities; import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig; -import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs; -import com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig; import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig; -import java.util.List; import java.util.Map; -import java.util.function.Function; import javax.annotation.Nonnull; -import javax.annotation.Nullable; import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -29,68 +20,24 @@ @RestController @RequestMapping("/orchestration") class OrchestrationController { - - private static final OrchestrationCompletionApi API = - new OrchestrationCompletionApi(getOrchestrationClient("default")); - static final String MODEL = "gpt-35-turbo"; - private static final LLMModuleConfig LLM_CONFIG = LLMModuleConfig.create().modelName(MODEL).modelParams(Map.of()); - private static final Function TEMPLATE_CONFIG = - (TemplatingModuleConfig templatingModuleConfig) -> - CompletionPostRequest.create() - .orchestrationConfig( - OrchestrationConfig.create() - .moduleConfigurations( - ModuleConfigs.create() - .llmModuleConfig(LLM_CONFIG) - .templatingModuleConfig(templatingModuleConfig))) - .inputParams(Map.of()); + private final OrchestrationClient client = new OrchestrationClient().withLlmConfig(LLM_CONFIG); /** - * Creates a config from a filter threshold. The config includes a template and has input and - * output filters + * Chat request to OpenAI through the Orchestration service with a template + * + * @return the assistant message response */ - private static final Function FILTERING_CONFIG = - (AzureThreshold filterThreshold) -> { - final var inputParams = - Map.of( - "disclaimer", - "```DISCLAIMER: The area surrounding the apartment is known for prostitutes and gang violence including armed conflicts, gun violence is frequent."); - final var template = - ChatMessage.create() - .role("user") - .content( - "Create a rental posting for subletting my apartment in the downtown area. Keep it short. Make sure to add the following disclaimer to the end. Do not change it! {{?disclaimer}}"); - final var templatingConfig = TemplatingModuleConfig.create().template(template); - - final var filter = - FilterConfig.create() - .type(FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY) - .config( - AzureContentSafety.create() - .hate(filterThreshold) - .selfHarm(filterThreshold) - .sexual(filterThreshold) - .violence(filterThreshold)); - - final var filteringConfig = - FilteringModuleConfig.create() - .input(FilteringConfig.create().filters(List.of(filter))) - .output(FilteringConfig.create().filters(List.of(filter))); + @GetMapping("/completion") + @Nonnull + public CompletionPostResponse completion() { + var prompt = new OrchestrationPrompt("What is the capital of France?"); - return CompletionPostRequest.create() - .orchestrationConfig( - OrchestrationConfig.create() - .moduleConfigurations( - ModuleConfigs.create() - .llmModuleConfig(LLM_CONFIG) - .templatingModuleConfig(templatingConfig) - .filteringModuleConfig(filteringConfig))) - .inputParams(inputParams); - }; + return client.chatCompletion(prompt); + } /** * Chat request to OpenAI through the Orchestration service with a template @@ -98,58 +45,39 @@ class OrchestrationController { * @return the assistant message response */ @GetMapping("/template") - @Nullable + @Nonnull public CompletionPostResponse template() { + var template = ChatMessage.create().role("user").content("Reply with 'The Orchestration Service is working!' in {{?language}}"); + var inputParams = Map.of("language", "german"); - final var template = ChatMessage.create().role("user").content("{{?input}}"); - final var inputParams = - Map.of("input", "Reply with 'Orchestration Service is working!' in German"); - - final var config = - TEMPLATE_CONFIG - .apply(TemplatingModuleConfig.create().template(template)) - .inputParams(inputParams); + var prompt = + new OrchestrationPrompt(inputParams) + .withTemplate(TemplatingModuleConfig.create().template(template)); - return API.orchestrationV1EndpointsCreate(config); + return client.chatCompletion(prompt); } - /** - * Chat request to OpenAI through the Orchestration service with a violent template and both input - * and output filters. - * - * @param threshold A high threshold is a loose filter, a low threshold is a strict filter - * @return the assistant message response - */ - @GetMapping("/filter/{threshold}") - @Nullable - public CompletionPostResponse filter(@Nonnull @PathVariable("threshold") final String threshold) { - - final var config = - FILTERING_CONFIG.apply(AzureThreshold.fromValue(Integer.parseInt(threshold))); + @GetMapping("/filter/{level}") + @Nonnull + public CompletionPostResponse filter(@Nonnull Sensitivity level) { + var filter = new AzureContentFilter().hate(level); + var prompt = new OrchestrationPrompt("This prompt demonstrates how to hit the fucking input filter. And hit it hard, like we mean it.") + .withInputContentFilter(filter); - return API.orchestrationV1EndpointsCreate(config); + // if the level is strict, this will throw, if not it will return a result + return client.chatCompletion(prompt); } - /** - * Chat request to OpenAI through the Orchestration service with a template - * - * @return the assistant message response - */ - @GetMapping("/messagesHistory") - @Nullable - public CompletionPostResponse messagesHistory() { - final List messagesHistory = - List.of( - ChatMessage.create().role("user").content("What is the capital of France?"), - ChatMessage.create().role("assistant").content("The capital of France is Paris.")); - final var message = - ChatMessage.create().role("user").content("What is the typical food there?"); + @GetMapping("/masking") + @Nonnull + public CompletionPostResponse masking() { + var masking = DpiMaskingConfig.pseudonymization().withEntities(DPIEntities.EMAIL, DPIEntities.LOCATION); - final var config = - TEMPLATE_CONFIG - .apply(TemplatingModuleConfig.create().template(message)) - .messagesHistory(messagesHistory); + var prompt = + new OrchestrationPrompt( + "Please write 'Hello World!' to me via email. My email address is foo.bar@baz.ai") + .withMaskingConfig(masking); - return API.orchestrationV1EndpointsCreate(config); + return client.chatCompletion(prompt); } } diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java index fb716a2d..667920e4 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/OrchestrationTest.java @@ -1,20 +1,27 @@ package com.sap.ai.sdk.app.controllers; +import static com.sap.ai.sdk.orchestration.AzureContentFilter.Sensitivity.HIGH; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import com.sap.ai.sdk.orchestration.AzureContentFilter; +import com.sap.ai.sdk.orchestration.OrchestrationClientException; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.springframework.web.client.HttpClientErrorException; class OrchestrationTest { + OrchestrationController controller; + + @BeforeEach + void setUp() { + controller = new OrchestrationController(); + } + @Test - void template() { - final var result = new OrchestrationController().template(); + void testCompletion() { + final var result = controller.completion(); assertThat(result.getRequestId()).isNotEmpty(); - assertThat(result.getModuleResults().getTemplating().get(0).getContent()) - .isEqualTo("Reply with 'Orchestration Service is working!' in German"); - assertThat(result.getModuleResults().getTemplating().get(0).getRole()).isEqualTo("user"); var llm = result.getModuleResults().getLlm(); assertThat(llm.getId()).isNotEmpty(); assertThat(llm.getObject()).isEqualTo("chat.completion"); @@ -44,19 +51,29 @@ void template() { } @Test - void looseFilter() { - assertThat(new OrchestrationController().filter("4")).isNotNull(); + void testTemplate() { + var result = controller.template(); + + var templateResult = result.getModuleResults().getTemplating().get(0); + assertThat(templateResult.getContent()) + .isEqualTo("Reply with 'The Orchestration Service is working!' in german"); + assertThat(templateResult.getRole()).isEqualTo("user"); } @Test - void strictFilter() { - assertThatThrownBy(() -> new OrchestrationController().filter("0")) - .isInstanceOf(HttpClientErrorException.class) - .hasMessageContaining("400 Bad Request"); + void testLenientContentFilter() { + var result = controller.filter(AzureContentFilter.Sensitivity.LOW); + + var filterResult = result.getModuleResults().getInputFiltering(); + + assertThat(filterResult.getMessage()).contains("passed"); } @Test - void messagesHistory() { - assertThat(new OrchestrationController()).isNotNull(); + void testStrictContentFilter() { + assertThatThrownBy(() -> new OrchestrationController().filter(HIGH)) + .isInstanceOf(OrchestrationClientException.class) + .hasMessageContaining("400 Bad Request") + .hasMessageContaining("Content filtered"); } }