Skip to content

Commit

Permalink
Minor stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
MatKuhr committed Oct 15, 2024
1 parent b8414f2 commit bcc3a1a
Show file tree
Hide file tree
Showing 23 changed files with 636 additions and 489 deletions.
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
package com.sap.ai.sdk.orchestration;

import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.fromValue;

import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.experimental.Accessors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.fromValue;

@Data
@Accessors(fluent = true)
public class AzureContentFilter implements ContentFilter {
@Nullable
private Setting hate;
@Nullable
private Setting selfHarm;
@Nullable
private Setting sexual;
@Nullable
private Setting violence;
@Nullable private Setting hate;
@Nullable private Setting selfHarm;
@Nullable private Setting sexual;
@Nullable private Setting violence;

@RequiredArgsConstructor
public enum Setting {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import io.vavr.control.Option;

import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.Setter;

@Data
@Setter(AccessLevel.PRIVATE)
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public class DefaultOrchestrationConfig<T extends OrchestrationConfig<T>>
implements OrchestrationConfig<T> {

Expand All @@ -20,45 +21,98 @@ public class DefaultOrchestrationConfig<T extends OrchestrationConfig<T>>
@Nonnull private Option<ContentFilter> inputContentFilter = Option.none();
@Nonnull private Option<ContentFilter> outputContentFilter = Option.none();

@Nonnull
@Override
@Nonnull private final T wrapper;

@SuppressWarnings("unchecked")
public T instance() {
return (T) this;
private DefaultOrchestrationConfig() {
wrapper = (T) this;
}

/**
* Create a new instance of {@link DefaultOrchestrationConfig} to delegate to. This is useful when
* exposing the {@link OrchestrationConfig} in other objects, without re-implementing it. To
* maintain fluent API usage, the given wrapper object will be returned by the fluent methods,
* instead of this instance.
*
* @param wrapper The wrapper that delegates to this object.
* @param <T> The type of the wrapper object.
* @return The new instance.
* @see #standalone()
*/
@Nonnull
public static <T extends OrchestrationConfig<T>> DefaultOrchestrationConfig<T> asDelegateFor(
@Nonnull final T wrapper) {
return new DefaultOrchestrationConfig<>(wrapper);
}

/**
* Create an implementation without any object delegating to it. The fluent API will return this
* object itself.
*
* @return The new instance.
* @see #asDelegateFor(OrchestrationConfig)
*/
@Nonnull
public static DefaultOrchestrationConfig<?> standalone() {
return new DefaultOrchestrationConfig<>();
}

@Nonnull
@Override
public T withLlmConfig(@Nonnull final LLMModuleConfig llm) {
this.llmConfig = Option.some(llm);
return instance();
return wrapper;
}

@Nonnull
@Override
public T withTemplate(@Nonnull final TemplatingModuleConfig template) {
this.template = Option.some(template);
return instance();
return wrapper;
}

@Nonnull
@Override
public T withMaskingConfig(@Nonnull final MaskingConfig maskingConfig) {
this.maskingConfig = Option.some(maskingConfig);
return instance();
return wrapper;
}

@Nonnull
@Override
public T withInputContentFilter(@Nonnull final ContentFilter filter) {
this.inputContentFilter = Option.some(filter);
return instance();
return wrapper;
}

@Nonnull
@Override
public T withOutputContentFilter(@Nonnull final ContentFilter filter) {
this.outputContentFilter = Option.some(filter);
return instance();
return wrapper;
}

/**
* Copy the configuration into the given target configuration. The copy is
* <strong>shallow</strong> and does <strong>not override</strong> any existing configuration.
*
* <p>This has two main use cases:
*
* <ol>
* <li>Duplicating a config
* <li>Applying defaults to a config
* </ol>
*
* @param source The source configuration to copy from.
*/
public DefaultOrchestrationConfig<T> copyFrom(@Nonnull final OrchestrationConfig<?> source) {
llmConfig.orElse(source::getLlmConfig).forEach(this::withLlmConfig);
template.orElse(source::getTemplate).forEach(this::withTemplate);
maskingConfig.orElse(source::getMaskingConfig).forEach(this::withMaskingConfig);
inputContentFilter.orElse(source::getInputContentFilter).forEach(this::withInputContentFilter);
outputContentFilter
.orElse(source::getOutputContentFilter)
.forEach(this::withOutputContentFilter);
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import com.sap.ai.sdk.orchestration.client.model.DPIEntities;
import com.sap.ai.sdk.orchestration.client.model.DPIEntityConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingProviderConfig;
import java.util.List;
import javax.annotation.Nonnull;
Expand All @@ -33,11 +32,10 @@ public static Builder forPseudonymization() {
@Nonnull
MaskingProviderConfig toMaskingProviderDTO() {
var entities = this.entities.stream().map(it -> DPIEntityConfig.create().type(it)).toList();
return
MaskingProviderConfig.create()
.type(SAP_DATA_PRIVACY_INTEGRATION)
.method(maskingMethod)
.entities(entities);
return MaskingProviderConfig.create()
.type(SAP_DATA_PRIVACY_INTEGRATION)
.method(maskingMethod)
.entities(entities);
}

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,18 @@
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import io.vavr.control.Option;

import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;

class ModuleConfigFactory {
@Nonnull
static ModuleConfigs toModuleConfigDTO(
@Nonnull final OrchestrationConfig<?> config, @Nonnull final List<ChatMessage> messages) {
@Nonnull final OrchestrationConfig<?> config, @Nonnull final List<ChatMessage> messages) {
LLMModuleConfig llm =
config
.getLlmConfig()
.getOrElseThrow(() -> new IllegalStateException("LLM module config is required"));
config
.getLlmConfig()
.getOrElseThrow(() -> new IllegalStateException("LLM module config is required"));

/*
* Currently, we have to merge the prompt into the template configuration.
Expand All @@ -35,37 +34,39 @@ static ModuleConfigs toModuleConfigDTO(

if (messagesWithPrompt.isEmpty()) {
throw new IllegalStateException(
"A prompt is required. Pass at least one message or configure the templating module.");
"A prompt is required. Pass at least one message or configure the templating module.");
}
TemplatingModuleConfig template = TemplatingModuleConfig.create().template(messagesWithPrompt);
maybeTemplate.map(TemplatingModuleConfig::getDefaults).forEach(template::defaults);

ModuleConfigs dto =
ModuleConfigs.create().llmModuleConfig(llm).templatingModuleConfig(template);
ModuleConfigs.create().llmModuleConfig(llm).templatingModuleConfig(template);

config
.getMaskingConfig()
.filter(DpiMaskingConfig.class::isInstance)
.map(DpiMaskingConfig.class::cast)
.map(DpiMaskingConfig::toMaskingProviderDTO)
.map(it -> MaskingModuleConfig.create().maskingProviders(it))
.forEach(dto::maskingModuleConfig);
.getMaskingConfig()
.filter(DpiMaskingConfig.class::isInstance)
.map(DpiMaskingConfig.class::cast)
.map(DpiMaskingConfig::toMaskingProviderDTO)
.map(it -> MaskingModuleConfig.create().maskingProviders(it))
.forEach(dto::maskingModuleConfig);

var maybeInputFilter = config.getInputContentFilter();
var maybeOutputFilter = config.getOutputContentFilter();

if (maybeInputFilter.isDefined() || maybeOutputFilter.isDefined()) {
var filter = FilteringModuleConfig.create();
maybeInputFilter
.filter(AzureContentFilter.class::isInstance)
.map(AzureContentFilter.class::cast)
.map(AzureContentFilter::toFilterConfigDTO)
.map(it -> FilteringConfig.create().filters(it)).forEach(filter::input);
.filter(AzureContentFilter.class::isInstance)
.map(AzureContentFilter.class::cast)
.map(AzureContentFilter::toFilterConfigDTO)
.map(it -> FilteringConfig.create().filters(it))
.forEach(filter::input);
maybeOutputFilter
.filter(AzureContentFilter.class::isInstance)
.map(AzureContentFilter.class::cast)
.map(AzureContentFilter::toFilterConfigDTO)
.map(it -> FilteringConfig.create().filters(it)).forEach(filter::output);
.filter(AzureContentFilter.class::isInstance)
.map(AzureContentFilter.class::cast)
.map(AzureContentFilter::toFilterConfigDTO)
.map(it -> FilteringConfig.create().filters(it))
.forEach(filter::output);
dto = dto.filteringModuleConfig(filter);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class OrchestrationClient implements OrchestrationConfig<OrchestrationCli

@Delegate @Nonnull
private final DefaultOrchestrationConfig<OrchestrationClient> clientConfig =
new DefaultOrchestrationConfig<>();
DefaultOrchestrationConfig.asDelegateFor(this);

@Nonnull private final HttpDestination destination;

Expand All @@ -47,12 +47,6 @@ public OrchestrationClient() {
this.destination = Core.getDestinationForDeployment("db1d64d9f06be467", "default").asHttp();
}

@Nonnull
@Override
public OrchestrationClient instance() {
return this;
}

/**
* Generate a completion for the given user prompt.
*
Expand All @@ -61,11 +55,15 @@ public OrchestrationClient instance() {
* @throws OrchestrationClientException if the request fails
*/
@Nonnull
// TODO: decide if we want to offer this in addition to the already simple `new Prompt("asdf")`
public String chatCompletion(@Nonnull final String userPrompt)
throws OrchestrationClientException {
var response = chatCompletion(new OrchestrationPrompt(userPrompt));
return response.getOrchestrationResult().getChoices().get(0).getMessage().getContent();

var choice = response.getOrchestrationResult().getChoices().get(0);
if (choice.getFinishReason().equalsIgnoreCase("content_filter")) {
throw new OrchestrationClientException("Output content filter triggered");
}
return choice.getMessage().getContent();
}

/**
Expand All @@ -78,14 +76,7 @@ public String chatCompletion(@Nonnull final String userPrompt)
@Nonnull
public CompletionPostResponse chatCompletion(@Nonnull final OrchestrationPrompt prompt)
throws OrchestrationClientException {
var moduleConfigsDto = prompt.toModuleConfigDTO(clientConfig);
var dto =
CompletionPostRequest.create()
.orchestrationConfig(
com.sap.ai.sdk.orchestration.client.model.OrchestrationConfig.create()
.moduleConfigurations(moduleConfigsDto))
.messagesHistory(prompt.getMessages())
.inputParams(prompt.getTemplateParameters());
var dto = prompt.toCompletionPostRequestDTO(clientConfig);

return executeRequest(dto);
}
Expand All @@ -106,6 +97,7 @@ public Stream<String> streamChatCompletionDelta(@Nonnull final OrchestrationProm
@Nonnull
protected CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request) {
final var client = ApacheHttpClient5Accessor.getHttpClient(destination);
// TODO: update after AiCoreService refactoring
final BasicClassicHttpRequest postRequest =
new HttpPost("/v2/inference/deployments/db1d64d9f06be467/completion");
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
* @param <T> Type of the specific implementation to make a fluent API possible.
*/
public interface OrchestrationConfig<T extends OrchestrationConfig<T>> {
@Nonnull
T instance();

@Nonnull
Option<LLMModuleConfig> getLlmConfig();
Expand Down Expand Up @@ -44,30 +42,4 @@ public interface OrchestrationConfig<T extends OrchestrationConfig<T>> {

@Nonnull
T withOutputContentFilter(@Nonnull final ContentFilter filter);

/**
* Copy the configuration into the given target configuration. The copy is
* <strong>shallow</strong> and does <strong>not override</strong> any existing configuration.
*
* <p>This has two main use cases:
*
* <ol>
* <li>Duplicating a config
* <li>Applying defaults to a config
* </ol>
*
* @param source The source configuration to copy from.
*/
default T copyOrchestrationConfigurationFrom(@Nonnull final OrchestrationConfig<?> source) {
getLlmConfig().orElse(source::getLlmConfig).forEach(this::withLlmConfig);
getTemplate().orElse(source::getTemplate).forEach(this::withTemplate);
getMaskingConfig().orElse(source::getMaskingConfig).forEach(this::withMaskingConfig);
getInputContentFilter()
.orElse(source::getInputContentFilter)
.forEach(this::withInputContentFilter);
getOutputContentFilter()
.orElse(source::getOutputContentFilter)
.forEach(this::withOutputContentFilter);
return instance();
}
}
Loading

0 comments on commit bcc3a1a

Please sign in to comment.